mirror of
https://github.com/amnezia-vpn/amneziawg-go.git
synced 2025-04-22 08:56:54 +02:00
Compare commits
63 commits
Author | SHA1 | Date | |
---|---|---|---|
|
27e661d68e | ||
|
71be0eb3a6 | ||
|
e3f1273f8a | ||
|
c97b5b7615 | ||
|
668ddfd455 | ||
|
b8da08c106 | ||
|
2e3f7d122c | ||
|
2e7780471a | ||
|
87d8c00f86 | ||
|
c00bda9200 | ||
|
d2b0fc9789 | ||
|
77d39ff3b9 | ||
|
e433d13df6 | ||
|
3ddf952973 | ||
|
3f0a3bcfa0 | ||
|
4dddf62e57 | ||
|
827ec6e14b | ||
|
92e28a0d14 | ||
|
52fed4d362 | ||
|
9c6b3ff332 | ||
|
7de7a9a754 | ||
|
0c347529b8 | ||
|
6705978fc8 | ||
|
032e33f577 | ||
|
59101fd202 | ||
|
8bcfbac230 | ||
|
f0dfb5eacc | ||
|
9195025d8f | ||
|
cbd414dfec | ||
|
7155d20913 | ||
|
bfeb3954f6 | ||
|
e3c9ec8012 | ||
|
ce9d3866a3 | ||
|
e5f355e843 | ||
|
c05b2ee2a3 | ||
|
180c9284f3 | ||
|
015e11875d | ||
|
12269c2761 | ||
|
542e565baa | ||
|
7c20311b3d | ||
|
4ffa9c2032 | ||
|
d0bc03c707 | ||
|
1cf89f5339 | ||
|
b43118018e | ||
|
7af55a3e6f | ||
|
c493b95f66 | ||
|
2e0774f246 | ||
|
b3df23dcd4 | ||
|
f502ec3fad | ||
|
5d37bd24e1 | ||
|
24ea13351e | ||
|
177caa7e44 | ||
|
b81ca925db | ||
|
42ec952ead | ||
|
ec8f6f82c2 | ||
|
1ec454f253 | ||
|
8a015f7c76 | ||
|
895d6c23cd | ||
|
4201e08f1d | ||
|
6a84778f2c | ||
|
b34974c476 | ||
|
f30419e0d1 | ||
|
8f1a6a10b2 |
64 changed files with 3473 additions and 1847 deletions
41
.github/workflows/build-if-tag.yml
vendored
Normal file
41
.github/workflows/build-if-tag.yml
vendored
Normal file
|
@ -0,0 +1,41 @@
|
||||||
|
name: build-if-tag
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
tags:
|
||||||
|
- 'v[0-9]+.[0-9]+.[0-9]+'
|
||||||
|
|
||||||
|
env:
|
||||||
|
APP: amneziawg-go
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
build:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
name: build
|
||||||
|
steps:
|
||||||
|
- name: Checkout
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
ref: ${{ github.ref_name }}
|
||||||
|
|
||||||
|
- name: Login to Docker Hub
|
||||||
|
uses: docker/login-action@v3
|
||||||
|
with:
|
||||||
|
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||||
|
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||||
|
|
||||||
|
- name: Setup metadata
|
||||||
|
uses: docker/metadata-action@v5
|
||||||
|
id: metadata
|
||||||
|
with:
|
||||||
|
images: amneziavpn/${{ env.APP }}
|
||||||
|
tags: type=semver,pattern={{version}}
|
||||||
|
|
||||||
|
- name: Set up Docker Buildx
|
||||||
|
uses: docker/setup-buildx-action@v3
|
||||||
|
|
||||||
|
- name: Build
|
||||||
|
uses: docker/build-push-action@v5
|
||||||
|
with:
|
||||||
|
push: true
|
||||||
|
tags: ${{ steps.metadata.outputs.tags }}
|
2
.gitignore
vendored
2
.gitignore
vendored
|
@ -1 +1 @@
|
||||||
wireguard-go
|
amneziawg-go
|
17
Dockerfile
Normal file
17
Dockerfile
Normal file
|
@ -0,0 +1,17 @@
|
||||||
|
FROM golang:1.24 as awg
|
||||||
|
COPY . /awg
|
||||||
|
WORKDIR /awg
|
||||||
|
RUN go mod download && \
|
||||||
|
go mod verify && \
|
||||||
|
go build -ldflags '-linkmode external -extldflags "-fno-PIC -static"' -v -o /usr/bin
|
||||||
|
|
||||||
|
FROM alpine:3.19
|
||||||
|
ARG AWGTOOLS_RELEASE="1.0.20241018"
|
||||||
|
RUN apk --no-cache add iproute2 iptables bash && \
|
||||||
|
cd /usr/bin/ && \
|
||||||
|
wget https://github.com/amnezia-vpn/amneziawg-tools/releases/download/v${AWGTOOLS_RELEASE}/alpine-3.19-amneziawg-tools.zip && \
|
||||||
|
unzip -j alpine-3.19-amneziawg-tools.zip && \
|
||||||
|
chmod +x /usr/bin/awg /usr/bin/awg-quick && \
|
||||||
|
ln -s /usr/bin/awg /usr/bin/wg && \
|
||||||
|
ln -s /usr/bin/awg-quick /usr/bin/wg-quick
|
||||||
|
COPY --from=awg /usr/bin/amneziawg-go /usr/bin/amneziawg-go
|
12
Makefile
12
Makefile
|
@ -9,23 +9,23 @@ MAKEFLAGS += --no-print-directory
|
||||||
|
|
||||||
generate-version-and-build:
|
generate-version-and-build:
|
||||||
@export GIT_CEILING_DIRECTORIES="$(realpath $(CURDIR)/..)" && \
|
@export GIT_CEILING_DIRECTORIES="$(realpath $(CURDIR)/..)" && \
|
||||||
tag="$$(git describe --dirty 2>/dev/null)" && \
|
tag="$$(git describe --tags --dirty 2>/dev/null)" && \
|
||||||
ver="$$(printf 'package main\n\nconst Version = "%s"\n' "$$tag")" && \
|
ver="$$(printf 'package main\n\nconst Version = "%s"\n' "$$tag")" && \
|
||||||
[ "$$(cat version.go 2>/dev/null)" != "$$ver" ] && \
|
[ "$$(cat version.go 2>/dev/null)" != "$$ver" ] && \
|
||||||
echo "$$ver" > version.go && \
|
echo "$$ver" > version.go && \
|
||||||
git update-index --assume-unchanged version.go || true
|
git update-index --assume-unchanged version.go || true
|
||||||
@$(MAKE) wireguard-go
|
@$(MAKE) amneziawg-go
|
||||||
|
|
||||||
wireguard-go: $(wildcard *.go) $(wildcard */*.go)
|
amneziawg-go: $(wildcard *.go) $(wildcard */*.go)
|
||||||
go build -v -o "$@"
|
go build -v -o "$@"
|
||||||
|
|
||||||
install: wireguard-go
|
install: amneziawg-go
|
||||||
@install -v -d "$(DESTDIR)$(BINDIR)" && install -v -m 0755 "$<" "$(DESTDIR)$(BINDIR)/wireguard-go"
|
@install -v -d "$(DESTDIR)$(BINDIR)" && install -v -m 0755 "$<" "$(DESTDIR)$(BINDIR)/amneziawg-go"
|
||||||
|
|
||||||
test:
|
test:
|
||||||
go test ./...
|
go test ./...
|
||||||
|
|
||||||
clean:
|
clean:
|
||||||
rm -f wireguard-go
|
rm -f amneziawg-go
|
||||||
|
|
||||||
.PHONY: all clean test install generate-version-and-build
|
.PHONY: all clean test install generate-version-and-build
|
||||||
|
|
59
README.md
59
README.md
|
@ -1,24 +1,27 @@
|
||||||
# Go Implementation of [WireGuard](https://www.wireguard.com/)
|
# Go Implementation of AmneziaWG
|
||||||
|
|
||||||
This is an implementation of WireGuard in Go.
|
AmneziaWG is a contemporary version of the WireGuard protocol. It's a fork of WireGuard-Go and offers protection against detection by Deep Packet Inspection (DPI) systems. At the same time, it retains the simplified architecture and high performance of the original.
|
||||||
|
|
||||||
|
The precursor, WireGuard, is known for its efficiency but had issues with detection due to its distinctive packet signatures.
|
||||||
|
AmneziaWG addresses this problem by employing advanced obfuscation methods, allowing its traffic to blend seamlessly with regular internet traffic.
|
||||||
|
As a result, AmneziaWG maintains high performance while adding an extra layer of stealth, making it a superb choice for those seeking a fast and discreet VPN connection.
|
||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
|
|
||||||
Most Linux kernel WireGuard users are used to adding an interface with `ip link add wg0 type wireguard`. With wireguard-go, instead simply run:
|
Simply run:
|
||||||
|
|
||||||
```
|
```
|
||||||
$ wireguard-go wg0
|
$ amneziawg-go wg0
|
||||||
```
|
```
|
||||||
|
|
||||||
This will create an interface and fork into the background. To remove the interface, use the usual `ip link del wg0`, or if your system does not support removing interfaces directly, you may instead remove the control socket via `rm -f /var/run/wireguard/wg0.sock`, which will result in wireguard-go shutting down.
|
This will create an interface and fork into the background. To remove the interface, use the usual `ip link del wg0`, or if your system does not support removing interfaces directly, you may instead remove the control socket via `rm -f /var/run/amneziawg/wg0.sock`, which will result in amneziawg-go shutting down.
|
||||||
|
|
||||||
To run wireguard-go without forking to the background, pass `-f` or `--foreground`:
|
To run amneziawg-go without forking to the background, pass `-f` or `--foreground`:
|
||||||
|
|
||||||
```
|
```
|
||||||
$ wireguard-go -f wg0
|
$ amneziawg-go -f wg0
|
||||||
```
|
```
|
||||||
|
When an interface is running, you may use [`amneziawg-tools `](https://github.com/amnezia-vpn/amneziawg-tools) to configure it, as well as the usual `ip(8)` and `ifconfig(8)` commands.
|
||||||
When an interface is running, you may use [`wg(8)`](https://git.zx2c4.com/wireguard-tools/about/src/man/wg.8) to configure it, as well as the usual `ip(8)` and `ifconfig(8)` commands.
|
|
||||||
|
|
||||||
To run with more logging you may set the environment variable `LOG_LEVEL=debug`.
|
To run with more logging you may set the environment variable `LOG_LEVEL=debug`.
|
||||||
|
|
||||||
|
@ -26,52 +29,24 @@ To run with more logging you may set the environment variable `LOG_LEVEL=debug`.
|
||||||
|
|
||||||
### Linux
|
### Linux
|
||||||
|
|
||||||
This will run on Linux; however you should instead use the kernel module, which is faster and better integrated into the OS. See the [installation page](https://www.wireguard.com/install/) for instructions.
|
This will run on Linux; you should run amnezia-wg instead of using default linux kernel module.
|
||||||
|
|
||||||
### macOS
|
### macOS
|
||||||
|
|
||||||
This runs on macOS using the utun driver. It does not yet support sticky sockets, and won't support fwmarks because of Darwin limitations. Since the utun driver cannot have arbitrary interface names, you must either use `utun[0-9]+` for an explicit interface name or `utun` to have the kernel select one for you. If you choose `utun` as the interface name, and the environment variable `WG_TUN_NAME_FILE` is defined, then the actual name of the interface chosen by the kernel is written to the file specified by that variable.
|
This runs on macOS using the utun driver. It does not yet support sticky sockets, and won't support fwmarks because of Darwin limitations. Since the utun driver cannot have arbitrary interface names, you must either use `utun[0-9]+` for an explicit interface name or `utun` to have the kernel select one for you. If you choose `utun` as the interface name, and the environment variable `WG_TUN_NAME_FILE` is defined, then the actual name of the interface chosen by the kernel is written to the file specified by that variable.
|
||||||
|
This runs on MacOS, you should use it from [amneziawg-apple](https://github.com/amnezia-vpn/amneziawg-apple)
|
||||||
|
|
||||||
### Windows
|
### Windows
|
||||||
|
|
||||||
This runs on Windows, but you should instead use it from the more [fully featured Windows app](https://git.zx2c4.com/wireguard-windows/about/), which uses this as a module.
|
This runs on Windows, you should use it from [amneziawg-windows](https://github.com/amnezia-vpn/amneziawg-windows), which uses this as a module.
|
||||||
|
|
||||||
### FreeBSD
|
|
||||||
|
|
||||||
This will run on FreeBSD. It does not yet support sticky sockets. Fwmark is mapped to `SO_USER_COOKIE`.
|
|
||||||
|
|
||||||
### OpenBSD
|
|
||||||
|
|
||||||
This will run on OpenBSD. It does not yet support sticky sockets. Fwmark is mapped to `SO_RTABLE`. Since the tun driver cannot have arbitrary interface names, you must either use `tun[0-9]+` for an explicit interface name or `tun` to have the program select one for you. If you choose `tun` as the interface name, and the environment variable `WG_TUN_NAME_FILE` is defined, then the actual name of the interface chosen by the kernel is written to the file specified by that variable.
|
|
||||||
|
|
||||||
## Building
|
## Building
|
||||||
|
|
||||||
This requires an installation of the latest version of [Go](https://go.dev/).
|
This requires an installation of the latest version of [Go](https://go.dev/).
|
||||||
|
|
||||||
```
|
```
|
||||||
$ git clone https://git.zx2c4.com/wireguard-go
|
$ git clone https://github.com/amnezia-vpn/amneziawg-go
|
||||||
$ cd wireguard-go
|
$ cd amneziawg-go
|
||||||
$ make
|
$ make
|
||||||
```
|
```
|
||||||
|
|
||||||
## License
|
|
||||||
|
|
||||||
Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
|
||||||
|
|
||||||
Permission is hereby granted, free of charge, to any person obtaining a copy of
|
|
||||||
this software and associated documentation files (the "Software"), to deal in
|
|
||||||
the Software without restriction, including without limitation the rights to
|
|
||||||
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
|
|
||||||
of the Software, and to permit persons to whom the Software is furnished to do
|
|
||||||
so, subject to the following conditions:
|
|
||||||
|
|
||||||
The above copyright notice and this permission notice shall be included in all
|
|
||||||
copies or substantial portions of the Software.
|
|
||||||
|
|
||||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
||||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
||||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
||||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
||||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
||||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
||||||
SOFTWARE.
|
|
||||||
|
|
406
conn/bind_std.go
406
conn/bind_std.go
|
@ -8,6 +8,7 @@ package conn
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
@ -29,16 +30,19 @@ var (
|
||||||
// methods for sending and receiving multiple datagrams per-syscall. See the
|
// methods for sending and receiving multiple datagrams per-syscall. See the
|
||||||
// proposal in https://github.com/golang/go/issues/45886#issuecomment-1218301564.
|
// proposal in https://github.com/golang/go/issues/45886#issuecomment-1218301564.
|
||||||
type StdNetBind struct {
|
type StdNetBind struct {
|
||||||
mu sync.Mutex // protects all fields except as specified
|
mu sync.Mutex // protects all fields except as specified
|
||||||
ipv4 *net.UDPConn
|
ipv4 *net.UDPConn
|
||||||
ipv6 *net.UDPConn
|
ipv6 *net.UDPConn
|
||||||
ipv4PC *ipv4.PacketConn // will be nil on non-Linux
|
ipv4PC *ipv4.PacketConn // will be nil on non-Linux
|
||||||
ipv6PC *ipv6.PacketConn // will be nil on non-Linux
|
ipv6PC *ipv6.PacketConn // will be nil on non-Linux
|
||||||
|
ipv4TxOffload bool
|
||||||
|
ipv4RxOffload bool
|
||||||
|
ipv6TxOffload bool
|
||||||
|
ipv6RxOffload bool
|
||||||
|
|
||||||
// these three fields are not guarded by mu
|
// these two fields are not guarded by mu
|
||||||
udpAddrPool sync.Pool
|
udpAddrPool sync.Pool
|
||||||
ipv4MsgsPool sync.Pool
|
msgsPool sync.Pool
|
||||||
ipv6MsgsPool sync.Pool
|
|
||||||
|
|
||||||
blackhole4 bool
|
blackhole4 bool
|
||||||
blackhole6 bool
|
blackhole6 bool
|
||||||
|
@ -54,23 +58,14 @@ func NewStdNetBind() Bind {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
||||||
ipv4MsgsPool: sync.Pool{
|
msgsPool: sync.Pool{
|
||||||
New: func() any {
|
|
||||||
msgs := make([]ipv4.Message, IdealBatchSize)
|
|
||||||
for i := range msgs {
|
|
||||||
msgs[i].Buffers = make(net.Buffers, 1)
|
|
||||||
msgs[i].OOB = make([]byte, srcControlSize)
|
|
||||||
}
|
|
||||||
return &msgs
|
|
||||||
},
|
|
||||||
},
|
|
||||||
|
|
||||||
ipv6MsgsPool: sync.Pool{
|
|
||||||
New: func() any {
|
New: func() any {
|
||||||
|
// ipv6.Message and ipv4.Message are interchangeable as they are
|
||||||
|
// both aliases for x/net/internal/socket.Message.
|
||||||
msgs := make([]ipv6.Message, IdealBatchSize)
|
msgs := make([]ipv6.Message, IdealBatchSize)
|
||||||
for i := range msgs {
|
for i := range msgs {
|
||||||
msgs[i].Buffers = make(net.Buffers, 1)
|
msgs[i].Buffers = make(net.Buffers, 1)
|
||||||
msgs[i].OOB = make([]byte, srcControlSize)
|
msgs[i].OOB = make([]byte, 0, stickyControlSize+gsoControlSize)
|
||||||
}
|
}
|
||||||
return &msgs
|
return &msgs
|
||||||
},
|
},
|
||||||
|
@ -113,7 +108,7 @@ func (e *StdNetEndpoint) DstIP() netip.Addr {
|
||||||
return e.AddrPort.Addr()
|
return e.AddrPort.Addr()
|
||||||
}
|
}
|
||||||
|
|
||||||
// See sticky_default,linux, etc for implementations of SrcIP and SrcIfidx.
|
// See control_default,linux, etc for implementations of SrcIP and SrcIfidx.
|
||||||
|
|
||||||
func (e *StdNetEndpoint) DstToBytes() []byte {
|
func (e *StdNetEndpoint) DstToBytes() []byte {
|
||||||
b, _ := e.AddrPort.MarshalBinary()
|
b, _ := e.AddrPort.MarshalBinary()
|
||||||
|
@ -179,19 +174,21 @@ again:
|
||||||
}
|
}
|
||||||
var fns []ReceiveFunc
|
var fns []ReceiveFunc
|
||||||
if v4conn != nil {
|
if v4conn != nil {
|
||||||
if runtime.GOOS == "linux" {
|
s.ipv4TxOffload, s.ipv4RxOffload = supportsUDPOffload(v4conn)
|
||||||
|
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
|
||||||
v4pc = ipv4.NewPacketConn(v4conn)
|
v4pc = ipv4.NewPacketConn(v4conn)
|
||||||
s.ipv4PC = v4pc
|
s.ipv4PC = v4pc
|
||||||
}
|
}
|
||||||
fns = append(fns, s.makeReceiveIPv4(v4pc, v4conn))
|
fns = append(fns, s.makeReceiveIPv4(v4pc, v4conn, s.ipv4RxOffload))
|
||||||
s.ipv4 = v4conn
|
s.ipv4 = v4conn
|
||||||
}
|
}
|
||||||
if v6conn != nil {
|
if v6conn != nil {
|
||||||
if runtime.GOOS == "linux" {
|
s.ipv6TxOffload, s.ipv6RxOffload = supportsUDPOffload(v6conn)
|
||||||
|
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
|
||||||
v6pc = ipv6.NewPacketConn(v6conn)
|
v6pc = ipv6.NewPacketConn(v6conn)
|
||||||
s.ipv6PC = v6pc
|
s.ipv6PC = v6pc
|
||||||
}
|
}
|
||||||
fns = append(fns, s.makeReceiveIPv6(v6pc, v6conn))
|
fns = append(fns, s.makeReceiveIPv6(v6pc, v6conn, s.ipv6RxOffload))
|
||||||
s.ipv6 = v6conn
|
s.ipv6 = v6conn
|
||||||
}
|
}
|
||||||
if len(fns) == 0 {
|
if len(fns) == 0 {
|
||||||
|
@ -201,76 +198,101 @@ again:
|
||||||
return fns, uint16(port), nil
|
return fns, uint16(port), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn) ReceiveFunc {
|
func (s *StdNetBind) putMessages(msgs *[]ipv6.Message) {
|
||||||
return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
|
for i := range *msgs {
|
||||||
msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message)
|
(*msgs)[i].OOB = (*msgs)[i].OOB[:0]
|
||||||
defer s.ipv4MsgsPool.Put(msgs)
|
(*msgs)[i] = ipv6.Message{Buffers: (*msgs)[i].Buffers, OOB: (*msgs)[i].OOB}
|
||||||
for i := range bufs {
|
}
|
||||||
(*msgs)[i].Buffers[0] = bufs[i]
|
s.msgsPool.Put(msgs)
|
||||||
}
|
}
|
||||||
var numMsgs int
|
|
||||||
if runtime.GOOS == "linux" {
|
func (s *StdNetBind) getMessages() *[]ipv6.Message {
|
||||||
numMsgs, err = pc.ReadBatch(*msgs, 0)
|
return s.msgsPool.Get().(*[]ipv6.Message)
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
// If compilation fails here these are no longer the same underlying type.
|
||||||
|
_ ipv6.Message = ipv4.Message{}
|
||||||
|
)
|
||||||
|
|
||||||
|
type batchReader interface {
|
||||||
|
ReadBatch([]ipv6.Message, int) (int, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type batchWriter interface {
|
||||||
|
WriteBatch([]ipv6.Message, int) (int, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StdNetBind) receiveIP(
|
||||||
|
br batchReader,
|
||||||
|
conn *net.UDPConn,
|
||||||
|
rxOffload bool,
|
||||||
|
bufs [][]byte,
|
||||||
|
sizes []int,
|
||||||
|
eps []Endpoint,
|
||||||
|
) (n int, err error) {
|
||||||
|
msgs := s.getMessages()
|
||||||
|
for i := range bufs {
|
||||||
|
(*msgs)[i].Buffers[0] = bufs[i]
|
||||||
|
(*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)]
|
||||||
|
}
|
||||||
|
defer s.putMessages(msgs)
|
||||||
|
var numMsgs int
|
||||||
|
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
|
||||||
|
if rxOffload {
|
||||||
|
readAt := len(*msgs) - (IdealBatchSize / udpSegmentMaxDatagrams)
|
||||||
|
numMsgs, err = br.ReadBatch((*msgs)[readAt:], 0)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
numMsgs, err = splitCoalescedMessages(*msgs, readAt, getGSOSize)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
msg := &(*msgs)[0]
|
numMsgs, err = br.ReadBatch(*msgs, 0)
|
||||||
msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
numMsgs = 1
|
|
||||||
}
|
}
|
||||||
for i := 0; i < numMsgs; i++ {
|
} else {
|
||||||
msg := &(*msgs)[i]
|
msg := &(*msgs)[0]
|
||||||
sizes[i] = msg.N
|
msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
|
||||||
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
|
if err != nil {
|
||||||
ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
|
return 0, err
|
||||||
getSrcFromControl(msg.OOB[:msg.NN], ep)
|
|
||||||
eps[i] = ep
|
|
||||||
}
|
}
|
||||||
return numMsgs, nil
|
numMsgs = 1
|
||||||
|
}
|
||||||
|
for i := 0; i < numMsgs; i++ {
|
||||||
|
msg := &(*msgs)[i]
|
||||||
|
sizes[i] = msg.N
|
||||||
|
if sizes[i] == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
|
||||||
|
ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
|
||||||
|
getSrcFromControl(msg.OOB[:msg.NN], ep)
|
||||||
|
eps[i] = ep
|
||||||
|
}
|
||||||
|
return numMsgs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc {
|
||||||
|
return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
|
||||||
|
return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn) ReceiveFunc {
|
func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc {
|
||||||
return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
|
return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
|
||||||
msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message)
|
return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps)
|
||||||
defer s.ipv6MsgsPool.Put(msgs)
|
|
||||||
for i := range bufs {
|
|
||||||
(*msgs)[i].Buffers[0] = bufs[i]
|
|
||||||
}
|
|
||||||
var numMsgs int
|
|
||||||
if runtime.GOOS == "linux" {
|
|
||||||
numMsgs, err = pc.ReadBatch(*msgs, 0)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
msg := &(*msgs)[0]
|
|
||||||
msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
numMsgs = 1
|
|
||||||
}
|
|
||||||
for i := 0; i < numMsgs; i++ {
|
|
||||||
msg := &(*msgs)[i]
|
|
||||||
sizes[i] = msg.N
|
|
||||||
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
|
|
||||||
ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
|
|
||||||
getSrcFromControl(msg.OOB[:msg.NN], ep)
|
|
||||||
eps[i] = ep
|
|
||||||
}
|
|
||||||
return numMsgs, nil
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: When all Binds handle IdealBatchSize, remove this dynamic function and
|
// TODO: When all Binds handle IdealBatchSize, remove this dynamic function and
|
||||||
// rename the IdealBatchSize constant to BatchSize.
|
// rename the IdealBatchSize constant to BatchSize.
|
||||||
func (s *StdNetBind) BatchSize() int {
|
func (s *StdNetBind) BatchSize() int {
|
||||||
if runtime.GOOS == "linux" {
|
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
|
||||||
return IdealBatchSize
|
return IdealBatchSize
|
||||||
}
|
}
|
||||||
return 1
|
return 1
|
||||||
|
@ -293,28 +315,42 @@ func (s *StdNetBind) Close() error {
|
||||||
}
|
}
|
||||||
s.blackhole4 = false
|
s.blackhole4 = false
|
||||||
s.blackhole6 = false
|
s.blackhole6 = false
|
||||||
|
s.ipv4TxOffload = false
|
||||||
|
s.ipv4RxOffload = false
|
||||||
|
s.ipv6TxOffload = false
|
||||||
|
s.ipv6RxOffload = false
|
||||||
if err1 != nil {
|
if err1 != nil {
|
||||||
return err1
|
return err1
|
||||||
}
|
}
|
||||||
return err2
|
return err2
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type ErrUDPGSODisabled struct {
|
||||||
|
onLaddr string
|
||||||
|
RetryErr error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e ErrUDPGSODisabled) Error() string {
|
||||||
|
return fmt.Sprintf("disabled UDP GSO on %s, NIC(s) may not support checksum offload or peer MTU with protocol headers is greater than path MTU", e.onLaddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e ErrUDPGSODisabled) Unwrap() error {
|
||||||
|
return e.RetryErr
|
||||||
|
}
|
||||||
|
|
||||||
func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint) error {
|
func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint) error {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
blackhole := s.blackhole4
|
blackhole := s.blackhole4
|
||||||
conn := s.ipv4
|
conn := s.ipv4
|
||||||
var (
|
offload := s.ipv4TxOffload
|
||||||
pc4 *ipv4.PacketConn
|
br := batchWriter(s.ipv4PC)
|
||||||
pc6 *ipv6.PacketConn
|
|
||||||
)
|
|
||||||
is6 := false
|
is6 := false
|
||||||
if endpoint.DstIP().Is6() {
|
if endpoint.DstIP().Is6() {
|
||||||
blackhole = s.blackhole6
|
blackhole = s.blackhole6
|
||||||
conn = s.ipv6
|
conn = s.ipv6
|
||||||
pc6 = s.ipv6PC
|
br = s.ipv6PC
|
||||||
is6 = true
|
is6 = true
|
||||||
} else {
|
offload = s.ipv6TxOffload
|
||||||
pc4 = s.ipv4PC
|
|
||||||
}
|
}
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
|
|
||||||
|
@ -324,85 +360,185 @@ func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint) error {
|
||||||
if conn == nil {
|
if conn == nil {
|
||||||
return syscall.EAFNOSUPPORT
|
return syscall.EAFNOSUPPORT
|
||||||
}
|
}
|
||||||
|
|
||||||
|
msgs := s.getMessages()
|
||||||
|
defer s.putMessages(msgs)
|
||||||
|
ua := s.udpAddrPool.Get().(*net.UDPAddr)
|
||||||
|
defer s.udpAddrPool.Put(ua)
|
||||||
if is6 {
|
if is6 {
|
||||||
return s.send6(conn, pc6, endpoint, bufs)
|
as16 := endpoint.DstIP().As16()
|
||||||
|
copy(ua.IP, as16[:])
|
||||||
|
ua.IP = ua.IP[:16]
|
||||||
} else {
|
} else {
|
||||||
return s.send4(conn, pc4, endpoint, bufs)
|
as4 := endpoint.DstIP().As4()
|
||||||
|
copy(ua.IP, as4[:])
|
||||||
|
ua.IP = ua.IP[:4]
|
||||||
}
|
}
|
||||||
|
ua.Port = int(endpoint.(*StdNetEndpoint).Port())
|
||||||
|
var (
|
||||||
|
retried bool
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
retry:
|
||||||
|
if offload {
|
||||||
|
n := coalesceMessages(ua, endpoint.(*StdNetEndpoint), bufs, *msgs, setGSOSize)
|
||||||
|
err = s.send(conn, br, (*msgs)[:n])
|
||||||
|
if err != nil && offload && errShouldDisableUDPGSO(err) {
|
||||||
|
offload = false
|
||||||
|
s.mu.Lock()
|
||||||
|
if is6 {
|
||||||
|
s.ipv6TxOffload = false
|
||||||
|
} else {
|
||||||
|
s.ipv4TxOffload = false
|
||||||
|
}
|
||||||
|
s.mu.Unlock()
|
||||||
|
retried = true
|
||||||
|
goto retry
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for i := range bufs {
|
||||||
|
(*msgs)[i].Addr = ua
|
||||||
|
(*msgs)[i].Buffers[0] = bufs[i]
|
||||||
|
setSrcControl(&(*msgs)[i].OOB, endpoint.(*StdNetEndpoint))
|
||||||
|
}
|
||||||
|
err = s.send(conn, br, (*msgs)[:len(bufs)])
|
||||||
|
}
|
||||||
|
if retried {
|
||||||
|
return ErrUDPGSODisabled{onLaddr: conn.LocalAddr().String(), RetryErr: err}
|
||||||
|
}
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *StdNetBind) send4(conn *net.UDPConn, pc *ipv4.PacketConn, ep Endpoint, bufs [][]byte) error {
|
func (s *StdNetBind) send(conn *net.UDPConn, pc batchWriter, msgs []ipv6.Message) error {
|
||||||
ua := s.udpAddrPool.Get().(*net.UDPAddr)
|
|
||||||
as4 := ep.DstIP().As4()
|
|
||||||
copy(ua.IP, as4[:])
|
|
||||||
ua.IP = ua.IP[:4]
|
|
||||||
ua.Port = int(ep.(*StdNetEndpoint).Port())
|
|
||||||
msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message)
|
|
||||||
for i, buf := range bufs {
|
|
||||||
(*msgs)[i].Buffers[0] = buf
|
|
||||||
(*msgs)[i].Addr = ua
|
|
||||||
setSrcControl(&(*msgs)[i].OOB, ep.(*StdNetEndpoint))
|
|
||||||
}
|
|
||||||
var (
|
var (
|
||||||
n int
|
n int
|
||||||
err error
|
err error
|
||||||
start int
|
start int
|
||||||
)
|
)
|
||||||
if runtime.GOOS == "linux" {
|
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
|
||||||
for {
|
for {
|
||||||
n, err = pc.WriteBatch((*msgs)[start:len(bufs)], 0)
|
n, err = pc.WriteBatch(msgs[start:], 0)
|
||||||
if err != nil || n == len((*msgs)[start:len(bufs)]) {
|
if err != nil || n == len(msgs[start:]) {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
start += n
|
start += n
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for i, buf := range bufs {
|
for _, msg := range msgs {
|
||||||
_, _, err = conn.WriteMsgUDP(buf, (*msgs)[i].OOB, ua)
|
_, _, err = conn.WriteMsgUDP(msg.Buffers[0], msg.OOB, msg.Addr.(*net.UDPAddr))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
s.udpAddrPool.Put(ua)
|
|
||||||
s.ipv4MsgsPool.Put(msgs)
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *StdNetBind) send6(conn *net.UDPConn, pc *ipv6.PacketConn, ep Endpoint, bufs [][]byte) error {
|
const (
|
||||||
ua := s.udpAddrPool.Get().(*net.UDPAddr)
|
// Exceeding these values results in EMSGSIZE. They account for layer3 and
|
||||||
as16 := ep.DstIP().As16()
|
// layer4 headers. IPv6 does not need to account for itself as the payload
|
||||||
copy(ua.IP, as16[:])
|
// length field is self excluding.
|
||||||
ua.IP = ua.IP[:16]
|
maxIPv4PayloadLen = 1<<16 - 1 - 20 - 8
|
||||||
ua.Port = int(ep.(*StdNetEndpoint).Port())
|
maxIPv6PayloadLen = 1<<16 - 1 - 8
|
||||||
msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message)
|
|
||||||
for i, buf := range bufs {
|
// This is a hard limit imposed by the kernel.
|
||||||
(*msgs)[i].Buffers[0] = buf
|
udpSegmentMaxDatagrams = 64
|
||||||
(*msgs)[i].Addr = ua
|
)
|
||||||
setSrcControl(&(*msgs)[i].OOB, ep.(*StdNetEndpoint))
|
|
||||||
}
|
type setGSOFunc func(control *[]byte, gsoSize uint16)
|
||||||
|
|
||||||
|
func coalesceMessages(addr *net.UDPAddr, ep *StdNetEndpoint, bufs [][]byte, msgs []ipv6.Message, setGSO setGSOFunc) int {
|
||||||
var (
|
var (
|
||||||
n int
|
base = -1 // index of msg we are currently coalescing into
|
||||||
err error
|
gsoSize int // segmentation size of msgs[base]
|
||||||
start int
|
dgramCnt int // number of dgrams coalesced into msgs[base]
|
||||||
|
endBatch bool // tracking flag to start a new batch on next iteration of bufs
|
||||||
)
|
)
|
||||||
if runtime.GOOS == "linux" {
|
maxPayloadLen := maxIPv4PayloadLen
|
||||||
for {
|
if ep.DstIP().Is6() {
|
||||||
n, err = pc.WriteBatch((*msgs)[start:len(bufs)], 0)
|
maxPayloadLen = maxIPv6PayloadLen
|
||||||
if err != nil || n == len((*msgs)[start:len(bufs)]) {
|
}
|
||||||
break
|
for i, buf := range bufs {
|
||||||
|
if i > 0 {
|
||||||
|
msgLen := len(buf)
|
||||||
|
baseLenBefore := len(msgs[base].Buffers[0])
|
||||||
|
freeBaseCap := cap(msgs[base].Buffers[0]) - baseLenBefore
|
||||||
|
if msgLen+baseLenBefore <= maxPayloadLen &&
|
||||||
|
msgLen <= gsoSize &&
|
||||||
|
msgLen <= freeBaseCap &&
|
||||||
|
dgramCnt < udpSegmentMaxDatagrams &&
|
||||||
|
!endBatch {
|
||||||
|
msgs[base].Buffers[0] = append(msgs[base].Buffers[0], buf...)
|
||||||
|
if i == len(bufs)-1 {
|
||||||
|
setGSO(&msgs[base].OOB, uint16(gsoSize))
|
||||||
|
}
|
||||||
|
dgramCnt++
|
||||||
|
if msgLen < gsoSize {
|
||||||
|
// A smaller than gsoSize packet on the tail is legal, but
|
||||||
|
// it must end the batch.
|
||||||
|
endBatch = true
|
||||||
|
}
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
start += n
|
|
||||||
}
|
}
|
||||||
} else {
|
if dgramCnt > 1 {
|
||||||
for i, buf := range bufs {
|
setGSO(&msgs[base].OOB, uint16(gsoSize))
|
||||||
_, _, err = conn.WriteMsgUDP(buf, (*msgs)[i].OOB, ua)
|
}
|
||||||
if err != nil {
|
// Reset prior to incrementing base since we are preparing to start a
|
||||||
break
|
// new potential batch.
|
||||||
|
endBatch = false
|
||||||
|
base++
|
||||||
|
gsoSize = len(buf)
|
||||||
|
setSrcControl(&msgs[base].OOB, ep)
|
||||||
|
msgs[base].Buffers[0] = buf
|
||||||
|
msgs[base].Addr = addr
|
||||||
|
dgramCnt = 1
|
||||||
|
}
|
||||||
|
return base + 1
|
||||||
|
}
|
||||||
|
|
||||||
|
type getGSOFunc func(control []byte) (int, error)
|
||||||
|
|
||||||
|
func splitCoalescedMessages(msgs []ipv6.Message, firstMsgAt int, getGSO getGSOFunc) (n int, err error) {
|
||||||
|
for i := firstMsgAt; i < len(msgs); i++ {
|
||||||
|
msg := &msgs[i]
|
||||||
|
if msg.N == 0 {
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
var (
|
||||||
|
gsoSize int
|
||||||
|
start int
|
||||||
|
end = msg.N
|
||||||
|
numToSplit = 1
|
||||||
|
)
|
||||||
|
gsoSize, err = getGSO(msg.OOB[:msg.NN])
|
||||||
|
if err != nil {
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
if gsoSize > 0 {
|
||||||
|
numToSplit = (msg.N + gsoSize - 1) / gsoSize
|
||||||
|
end = gsoSize
|
||||||
|
}
|
||||||
|
for j := 0; j < numToSplit; j++ {
|
||||||
|
if n > i {
|
||||||
|
return n, errors.New("splitting coalesced packet resulted in overflow")
|
||||||
}
|
}
|
||||||
|
copied := copy(msgs[n].Buffers[0], msg.Buffers[0][start:end])
|
||||||
|
msgs[n].N = copied
|
||||||
|
msgs[n].Addr = msg.Addr
|
||||||
|
start = end
|
||||||
|
end += gsoSize
|
||||||
|
if end > msg.N {
|
||||||
|
end = msg.N
|
||||||
|
}
|
||||||
|
n++
|
||||||
|
}
|
||||||
|
if i != n-1 {
|
||||||
|
// It is legal for bytes to move within msg.Buffers[0] as a result
|
||||||
|
// of splitting, so we only zero the source msg len when it is not
|
||||||
|
// the destination of the last split operation above.
|
||||||
|
msg.N = 0
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
s.udpAddrPool.Put(ua)
|
return n, nil
|
||||||
s.ipv6MsgsPool.Put(msgs)
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,12 @@
|
||||||
package conn
|
package conn
|
||||||
|
|
||||||
import "testing"
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"golang.org/x/net/ipv6"
|
||||||
|
)
|
||||||
|
|
||||||
func TestStdNetBindReceiveFuncAfterClose(t *testing.T) {
|
func TestStdNetBindReceiveFuncAfterClose(t *testing.T) {
|
||||||
bind := NewStdNetBind().(*StdNetBind)
|
bind := NewStdNetBind().(*StdNetBind)
|
||||||
|
@ -20,3 +26,225 @@ func TestStdNetBindReceiveFuncAfterClose(t *testing.T) {
|
||||||
fn(bufs, sizes, eps)
|
fn(bufs, sizes, eps)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func mockSetGSOSize(control *[]byte, gsoSize uint16) {
|
||||||
|
*control = (*control)[:cap(*control)]
|
||||||
|
binary.LittleEndian.PutUint16(*control, gsoSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_coalesceMessages(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
buffs [][]byte
|
||||||
|
wantLens []int
|
||||||
|
wantGSO []int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "one message no coalesce",
|
||||||
|
buffs: [][]byte{
|
||||||
|
make([]byte, 1, 1),
|
||||||
|
},
|
||||||
|
wantLens: []int{1},
|
||||||
|
wantGSO: []int{0},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "two messages equal len coalesce",
|
||||||
|
buffs: [][]byte{
|
||||||
|
make([]byte, 1, 2),
|
||||||
|
make([]byte, 1, 1),
|
||||||
|
},
|
||||||
|
wantLens: []int{2},
|
||||||
|
wantGSO: []int{1},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "two messages unequal len coalesce",
|
||||||
|
buffs: [][]byte{
|
||||||
|
make([]byte, 2, 3),
|
||||||
|
make([]byte, 1, 1),
|
||||||
|
},
|
||||||
|
wantLens: []int{3},
|
||||||
|
wantGSO: []int{2},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "three messages second unequal len coalesce",
|
||||||
|
buffs: [][]byte{
|
||||||
|
make([]byte, 2, 3),
|
||||||
|
make([]byte, 1, 1),
|
||||||
|
make([]byte, 2, 2),
|
||||||
|
},
|
||||||
|
wantLens: []int{3, 2},
|
||||||
|
wantGSO: []int{2, 0},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "three messages limited cap coalesce",
|
||||||
|
buffs: [][]byte{
|
||||||
|
make([]byte, 2, 4),
|
||||||
|
make([]byte, 2, 2),
|
||||||
|
make([]byte, 2, 2),
|
||||||
|
},
|
||||||
|
wantLens: []int{4, 2},
|
||||||
|
wantGSO: []int{2, 0},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range cases {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
addr := &net.UDPAddr{
|
||||||
|
IP: net.ParseIP("127.0.0.1").To4(),
|
||||||
|
Port: 1,
|
||||||
|
}
|
||||||
|
msgs := make([]ipv6.Message, len(tt.buffs))
|
||||||
|
for i := range msgs {
|
||||||
|
msgs[i].Buffers = make([][]byte, 1)
|
||||||
|
msgs[i].OOB = make([]byte, 0, 2)
|
||||||
|
}
|
||||||
|
got := coalesceMessages(addr, &StdNetEndpoint{AddrPort: addr.AddrPort()}, tt.buffs, msgs, mockSetGSOSize)
|
||||||
|
if got != len(tt.wantLens) {
|
||||||
|
t.Fatalf("got len %d want: %d", got, len(tt.wantLens))
|
||||||
|
}
|
||||||
|
for i := 0; i < got; i++ {
|
||||||
|
if msgs[i].Addr != addr {
|
||||||
|
t.Errorf("msgs[%d].Addr != passed addr", i)
|
||||||
|
}
|
||||||
|
gotLen := len(msgs[i].Buffers[0])
|
||||||
|
if gotLen != tt.wantLens[i] {
|
||||||
|
t.Errorf("len(msgs[%d].Buffers[0]) %d != %d", i, gotLen, tt.wantLens[i])
|
||||||
|
}
|
||||||
|
gotGSO, err := mockGetGSOSize(msgs[i].OOB)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("msgs[%d] getGSOSize err: %v", i, err)
|
||||||
|
}
|
||||||
|
if gotGSO != tt.wantGSO[i] {
|
||||||
|
t.Errorf("msgs[%d] gsoSize %d != %d", i, gotGSO, tt.wantGSO[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mockGetGSOSize(control []byte) (int, error) {
|
||||||
|
if len(control) < 2 {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
return int(binary.LittleEndian.Uint16(control)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_splitCoalescedMessages(t *testing.T) {
|
||||||
|
newMsg := func(n, gso int) ipv6.Message {
|
||||||
|
msg := ipv6.Message{
|
||||||
|
Buffers: [][]byte{make([]byte, 1<<16-1)},
|
||||||
|
N: n,
|
||||||
|
OOB: make([]byte, 2),
|
||||||
|
}
|
||||||
|
binary.LittleEndian.PutUint16(msg.OOB, uint16(gso))
|
||||||
|
if gso > 0 {
|
||||||
|
msg.NN = 2
|
||||||
|
}
|
||||||
|
return msg
|
||||||
|
}
|
||||||
|
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
msgs []ipv6.Message
|
||||||
|
firstMsgAt int
|
||||||
|
wantNumEval int
|
||||||
|
wantMsgLens []int
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "second last split last empty",
|
||||||
|
msgs: []ipv6.Message{
|
||||||
|
newMsg(0, 0),
|
||||||
|
newMsg(0, 0),
|
||||||
|
newMsg(3, 1),
|
||||||
|
newMsg(0, 0),
|
||||||
|
},
|
||||||
|
firstMsgAt: 2,
|
||||||
|
wantNumEval: 3,
|
||||||
|
wantMsgLens: []int{1, 1, 1, 0},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "second last no split last empty",
|
||||||
|
msgs: []ipv6.Message{
|
||||||
|
newMsg(0, 0),
|
||||||
|
newMsg(0, 0),
|
||||||
|
newMsg(1, 0),
|
||||||
|
newMsg(0, 0),
|
||||||
|
},
|
||||||
|
firstMsgAt: 2,
|
||||||
|
wantNumEval: 1,
|
||||||
|
wantMsgLens: []int{1, 0, 0, 0},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "second last no split last no split",
|
||||||
|
msgs: []ipv6.Message{
|
||||||
|
newMsg(0, 0),
|
||||||
|
newMsg(0, 0),
|
||||||
|
newMsg(1, 0),
|
||||||
|
newMsg(1, 0),
|
||||||
|
},
|
||||||
|
firstMsgAt: 2,
|
||||||
|
wantNumEval: 2,
|
||||||
|
wantMsgLens: []int{1, 1, 0, 0},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "second last no split last split",
|
||||||
|
msgs: []ipv6.Message{
|
||||||
|
newMsg(0, 0),
|
||||||
|
newMsg(0, 0),
|
||||||
|
newMsg(1, 0),
|
||||||
|
newMsg(3, 1),
|
||||||
|
},
|
||||||
|
firstMsgAt: 2,
|
||||||
|
wantNumEval: 4,
|
||||||
|
wantMsgLens: []int{1, 1, 1, 1},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "second last split last split",
|
||||||
|
msgs: []ipv6.Message{
|
||||||
|
newMsg(0, 0),
|
||||||
|
newMsg(0, 0),
|
||||||
|
newMsg(2, 1),
|
||||||
|
newMsg(2, 1),
|
||||||
|
},
|
||||||
|
firstMsgAt: 2,
|
||||||
|
wantNumEval: 4,
|
||||||
|
wantMsgLens: []int{1, 1, 1, 1},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "second last no split last split overflow",
|
||||||
|
msgs: []ipv6.Message{
|
||||||
|
newMsg(0, 0),
|
||||||
|
newMsg(0, 0),
|
||||||
|
newMsg(1, 0),
|
||||||
|
newMsg(4, 1),
|
||||||
|
},
|
||||||
|
firstMsgAt: 2,
|
||||||
|
wantNumEval: 4,
|
||||||
|
wantMsgLens: []int{1, 1, 1, 1},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range cases {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := splitCoalescedMessages(tt.msgs, 2, mockGetGSOSize)
|
||||||
|
if err != nil && !tt.wantErr {
|
||||||
|
t.Fatalf("err: %v", err)
|
||||||
|
}
|
||||||
|
if got != tt.wantNumEval {
|
||||||
|
t.Fatalf("got to eval: %d want: %d", got, tt.wantNumEval)
|
||||||
|
}
|
||||||
|
for i, msg := range tt.msgs {
|
||||||
|
if msg.N != tt.wantMsgLens[i] {
|
||||||
|
t.Fatalf("msg[%d].N: %d want: %d", i, msg.N, tt.wantMsgLens[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -17,7 +17,7 @@ import (
|
||||||
|
|
||||||
"golang.org/x/sys/windows"
|
"golang.org/x/sys/windows"
|
||||||
|
|
||||||
"github.com/amnezia-vpn/amnezia-wg/conn/winrio"
|
"github.com/amnezia-vpn/amneziawg-go/conn/winrio"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
|
|
@ -12,7 +12,7 @@ import (
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
"github.com/amnezia-vpn/amnezia-wg/conn"
|
"github.com/amnezia-vpn/amneziawg-go/conn"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ChannelBind struct {
|
type ChannelBind struct {
|
||||||
|
|
12
conn/errors_default.go
Normal file
12
conn/errors_default.go
Normal file
|
@ -0,0 +1,12 @@
|
||||||
|
//go:build !linux
|
||||||
|
|
||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package conn
|
||||||
|
|
||||||
|
func errShouldDisableUDPGSO(err error) bool {
|
||||||
|
return false
|
||||||
|
}
|
28
conn/errors_linux.go
Normal file
28
conn/errors_linux.go
Normal file
|
@ -0,0 +1,28 @@
|
||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package conn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
)
|
||||||
|
|
||||||
|
func errShouldDisableUDPGSO(err error) bool {
|
||||||
|
var serr *os.SyscallError
|
||||||
|
if errors.As(err, &serr) {
|
||||||
|
// EIO is returned by udp_send_skb() if the device driver does not have
|
||||||
|
// tx checksumming enabled, which is a hard requirement of UDP_SEGMENT.
|
||||||
|
// See:
|
||||||
|
// https://git.kernel.org/pub/scm/docs/man-pages/man-pages.git/tree/man7/udp.7?id=806eabd74910447f21005160e90957bde4db0183#n228
|
||||||
|
// https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/tree/net/ipv4/udp.c?h=v6.2&id=c9c3395d5e3dcc6daee66c6908354d47bf98cb0c#n942
|
||||||
|
// If gso_size + udp + ip headers > fragment size EINVAL is returned.
|
||||||
|
// It occurs when the peer mtu + wg headers is greater than path mtu.
|
||||||
|
return serr.Err == unix.EIO || serr.Err == unix.EINVAL
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
15
conn/features_default.go
Normal file
15
conn/features_default.go
Normal file
|
@ -0,0 +1,15 @@
|
||||||
|
//go:build !linux
|
||||||
|
// +build !linux
|
||||||
|
|
||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package conn
|
||||||
|
|
||||||
|
import "net"
|
||||||
|
|
||||||
|
func supportsUDPOffload(conn *net.UDPConn) (txOffload, rxOffload bool) {
|
||||||
|
return
|
||||||
|
}
|
31
conn/features_linux.go
Normal file
31
conn/features_linux.go
Normal file
|
@ -0,0 +1,31 @@
|
||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package conn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
)
|
||||||
|
|
||||||
|
func supportsUDPOffload(conn *net.UDPConn) (txOffload, rxOffload bool) {
|
||||||
|
rc, err := conn.SyscallConn()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = rc.Control(func(fd uintptr) {
|
||||||
|
_, errSyscall := unix.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_SEGMENT)
|
||||||
|
txOffload = errSyscall == nil
|
||||||
|
// getsockopt(IPPROTO_UDP, UDP_GRO) is not supported in android
|
||||||
|
// use setsockopt workaround
|
||||||
|
errSyscall = unix.SetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_GRO, 1)
|
||||||
|
rxOffload = errSyscall == nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return false, false
|
||||||
|
}
|
||||||
|
return txOffload, rxOffload
|
||||||
|
}
|
21
conn/gso_default.go
Normal file
21
conn/gso_default.go
Normal file
|
@ -0,0 +1,21 @@
|
||||||
|
//go:build !linux
|
||||||
|
|
||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package conn
|
||||||
|
|
||||||
|
// getGSOSize parses control for UDP_GRO and if found returns its GSO size data.
|
||||||
|
func getGSOSize(control []byte) (int, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// setGSOSize sets a UDP_SEGMENT in control based on gsoSize.
|
||||||
|
func setGSOSize(control *[]byte, gsoSize uint16) {
|
||||||
|
}
|
||||||
|
|
||||||
|
// gsoControlSize returns the recommended buffer size for pooling sticky and UDP
|
||||||
|
// offloading control data.
|
||||||
|
const gsoControlSize = 0
|
65
conn/gso_linux.go
Normal file
65
conn/gso_linux.go
Normal file
|
@ -0,0 +1,65 @@
|
||||||
|
//go:build linux
|
||||||
|
|
||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package conn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
sizeOfGSOData = 2
|
||||||
|
)
|
||||||
|
|
||||||
|
// getGSOSize parses control for UDP_GRO and if found returns its GSO size data.
|
||||||
|
func getGSOSize(control []byte) (int, error) {
|
||||||
|
var (
|
||||||
|
hdr unix.Cmsghdr
|
||||||
|
data []byte
|
||||||
|
rem = control
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
|
||||||
|
for len(rem) > unix.SizeofCmsghdr {
|
||||||
|
hdr, data, rem, err = unix.ParseOneSocketControlMessage(rem)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("error parsing socket control message: %w", err)
|
||||||
|
}
|
||||||
|
if hdr.Level == unix.SOL_UDP && hdr.Type == unix.UDP_GRO && len(data) >= sizeOfGSOData {
|
||||||
|
var gso uint16
|
||||||
|
copy(unsafe.Slice((*byte)(unsafe.Pointer(&gso)), sizeOfGSOData), data[:sizeOfGSOData])
|
||||||
|
return int(gso), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// setGSOSize sets a UDP_SEGMENT in control based on gsoSize. It leaves existing
|
||||||
|
// data in control untouched.
|
||||||
|
func setGSOSize(control *[]byte, gsoSize uint16) {
|
||||||
|
existingLen := len(*control)
|
||||||
|
avail := cap(*control) - existingLen
|
||||||
|
space := unix.CmsgSpace(sizeOfGSOData)
|
||||||
|
if avail < space {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
*control = (*control)[:cap(*control)]
|
||||||
|
gsoControl := (*control)[existingLen:]
|
||||||
|
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&(gsoControl)[0]))
|
||||||
|
hdr.Level = unix.SOL_UDP
|
||||||
|
hdr.Type = unix.UDP_SEGMENT
|
||||||
|
hdr.SetLen(unix.CmsgLen(sizeOfGSOData))
|
||||||
|
copy((gsoControl)[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&gsoSize)), sizeOfGSOData))
|
||||||
|
*control = (*control)[:existingLen+space]
|
||||||
|
}
|
||||||
|
|
||||||
|
// gsoControlSize returns the recommended buffer size for pooling UDP
|
||||||
|
// offloading control data.
|
||||||
|
var gsoControlSize = unix.CmsgSpace(sizeOfGSOData)
|
|
@ -21,8 +21,9 @@ func (e *StdNetEndpoint) SrcToString() string {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: macOS, FreeBSD and other BSDs likely do support this feature set, but
|
// TODO: macOS, FreeBSD and other BSDs likely do support the sticky sockets
|
||||||
// use alternatively named flags and need ports and require testing.
|
// {get,set}srcControl feature set, but use alternatively named flags and need
|
||||||
|
// ports and require testing.
|
||||||
|
|
||||||
// getSrcFromControl parses the control for PKTINFO and if found updates ep with
|
// getSrcFromControl parses the control for PKTINFO and if found updates ep with
|
||||||
// the source information found.
|
// the source information found.
|
||||||
|
@ -34,8 +35,8 @@ func getSrcFromControl(control []byte, ep *StdNetEndpoint) {
|
||||||
func setSrcControl(control *[]byte, ep *StdNetEndpoint) {
|
func setSrcControl(control *[]byte, ep *StdNetEndpoint) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// srcControlSize returns the recommended buffer size for pooling sticky control
|
// stickyControlSize returns the recommended buffer size for pooling sticky
|
||||||
// data.
|
// offloading control data.
|
||||||
const srcControlSize = 0
|
const stickyControlSize = 0
|
||||||
|
|
||||||
const StdNetSupportsStickySockets = false
|
const StdNetSupportsStickySockets = false
|
||||||
|
|
|
@ -105,6 +105,8 @@ func setSrcControl(control *[]byte, ep *StdNetEndpoint) {
|
||||||
*control = append(*control, ep.src...)
|
*control = append(*control, ep.src...)
|
||||||
}
|
}
|
||||||
|
|
||||||
var srcControlSize = unix.CmsgSpace(unix.SizeofInet6Pktinfo)
|
// stickyControlSize returns the recommended buffer size for pooling sticky
|
||||||
|
// offloading control data.
|
||||||
|
var stickyControlSize = unix.CmsgSpace(unix.SizeofInet6Pktinfo)
|
||||||
|
|
||||||
const StdNetSupportsStickySockets = true
|
const StdNetSupportsStickySockets = true
|
||||||
|
|
|
@ -60,7 +60,7 @@ func Test_setSrcControl(t *testing.T) {
|
||||||
}
|
}
|
||||||
setSrc(ep, netip.MustParseAddr("127.0.0.1"), 5)
|
setSrc(ep, netip.MustParseAddr("127.0.0.1"), 5)
|
||||||
|
|
||||||
control := make([]byte, srcControlSize)
|
control := make([]byte, stickyControlSize)
|
||||||
|
|
||||||
setSrcControl(&control, ep)
|
setSrcControl(&control, ep)
|
||||||
|
|
||||||
|
@ -89,7 +89,7 @@ func Test_setSrcControl(t *testing.T) {
|
||||||
}
|
}
|
||||||
setSrc(ep, netip.MustParseAddr("::1"), 5)
|
setSrc(ep, netip.MustParseAddr("::1"), 5)
|
||||||
|
|
||||||
control := make([]byte, srcControlSize)
|
control := make([]byte, stickyControlSize)
|
||||||
|
|
||||||
setSrcControl(&control, ep)
|
setSrcControl(&control, ep)
|
||||||
|
|
||||||
|
@ -113,7 +113,7 @@ func Test_setSrcControl(t *testing.T) {
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("ClearOnNoSrc", func(t *testing.T) {
|
t.Run("ClearOnNoSrc", func(t *testing.T) {
|
||||||
control := make([]byte, unix.CmsgLen(0))
|
control := make([]byte, stickyControlSize)
|
||||||
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
|
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
|
||||||
hdr.Level = 1
|
hdr.Level = 1
|
||||||
hdr.Type = 2
|
hdr.Type = 2
|
||||||
|
@ -129,7 +129,7 @@ func Test_setSrcControl(t *testing.T) {
|
||||||
|
|
||||||
func Test_getSrcFromControl(t *testing.T) {
|
func Test_getSrcFromControl(t *testing.T) {
|
||||||
t.Run("IPv4", func(t *testing.T) {
|
t.Run("IPv4", func(t *testing.T) {
|
||||||
control := make([]byte, unix.CmsgSpace(unix.SizeofInet4Pktinfo))
|
control := make([]byte, stickyControlSize)
|
||||||
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
|
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
|
||||||
hdr.Level = unix.IPPROTO_IP
|
hdr.Level = unix.IPPROTO_IP
|
||||||
hdr.Type = unix.IP_PKTINFO
|
hdr.Type = unix.IP_PKTINFO
|
||||||
|
@ -149,7 +149,7 @@ func Test_getSrcFromControl(t *testing.T) {
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
t.Run("IPv6", func(t *testing.T) {
|
t.Run("IPv6", func(t *testing.T) {
|
||||||
control := make([]byte, unix.CmsgSpace(unix.SizeofInet6Pktinfo))
|
control := make([]byte, stickyControlSize)
|
||||||
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
|
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
|
||||||
hdr.Level = unix.IPPROTO_IPV6
|
hdr.Level = unix.IPPROTO_IPV6
|
||||||
hdr.Type = unix.IPV6_PKTINFO
|
hdr.Type = unix.IPV6_PKTINFO
|
||||||
|
|
|
@ -8,7 +8,7 @@ package device
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
|
|
||||||
"github.com/amnezia-vpn/amnezia-wg/conn"
|
"github.com/amnezia-vpn/amneziawg-go/conn"
|
||||||
)
|
)
|
||||||
|
|
||||||
type DummyDatagram struct {
|
type DummyDatagram struct {
|
||||||
|
|
|
@ -19,13 +19,13 @@ import (
|
||||||
// call wg.Done to remove the initial reference.
|
// call wg.Done to remove the initial reference.
|
||||||
// When the refcount hits 0, the queue's channel is closed.
|
// When the refcount hits 0, the queue's channel is closed.
|
||||||
type outboundQueue struct {
|
type outboundQueue struct {
|
||||||
c chan *QueueOutboundElement
|
c chan *QueueOutboundElementsContainer
|
||||||
wg sync.WaitGroup
|
wg sync.WaitGroup
|
||||||
}
|
}
|
||||||
|
|
||||||
func newOutboundQueue() *outboundQueue {
|
func newOutboundQueue() *outboundQueue {
|
||||||
q := &outboundQueue{
|
q := &outboundQueue{
|
||||||
c: make(chan *QueueOutboundElement, QueueOutboundSize),
|
c: make(chan *QueueOutboundElementsContainer, QueueOutboundSize),
|
||||||
}
|
}
|
||||||
q.wg.Add(1)
|
q.wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
|
@ -37,13 +37,13 @@ func newOutboundQueue() *outboundQueue {
|
||||||
|
|
||||||
// A inboundQueue is similar to an outboundQueue; see those docs.
|
// A inboundQueue is similar to an outboundQueue; see those docs.
|
||||||
type inboundQueue struct {
|
type inboundQueue struct {
|
||||||
c chan *QueueInboundElement
|
c chan *QueueInboundElementsContainer
|
||||||
wg sync.WaitGroup
|
wg sync.WaitGroup
|
||||||
}
|
}
|
||||||
|
|
||||||
func newInboundQueue() *inboundQueue {
|
func newInboundQueue() *inboundQueue {
|
||||||
q := &inboundQueue{
|
q := &inboundQueue{
|
||||||
c: make(chan *QueueInboundElement, QueueInboundSize),
|
c: make(chan *QueueInboundElementsContainer, QueueInboundSize),
|
||||||
}
|
}
|
||||||
q.wg.Add(1)
|
q.wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
|
@ -72,7 +72,7 @@ func newHandshakeQueue() *handshakeQueue {
|
||||||
}
|
}
|
||||||
|
|
||||||
type autodrainingInboundQueue struct {
|
type autodrainingInboundQueue struct {
|
||||||
c chan *[]*QueueInboundElement
|
c chan *QueueInboundElementsContainer
|
||||||
}
|
}
|
||||||
|
|
||||||
// newAutodrainingInboundQueue returns a channel that will be drained when it gets GC'd.
|
// newAutodrainingInboundQueue returns a channel that will be drained when it gets GC'd.
|
||||||
|
@ -81,7 +81,7 @@ type autodrainingInboundQueue struct {
|
||||||
// some other means, such as sending a sentinel nil values.
|
// some other means, such as sending a sentinel nil values.
|
||||||
func newAutodrainingInboundQueue(device *Device) *autodrainingInboundQueue {
|
func newAutodrainingInboundQueue(device *Device) *autodrainingInboundQueue {
|
||||||
q := &autodrainingInboundQueue{
|
q := &autodrainingInboundQueue{
|
||||||
c: make(chan *[]*QueueInboundElement, QueueInboundSize),
|
c: make(chan *QueueInboundElementsContainer, QueueInboundSize),
|
||||||
}
|
}
|
||||||
runtime.SetFinalizer(q, device.flushInboundQueue)
|
runtime.SetFinalizer(q, device.flushInboundQueue)
|
||||||
return q
|
return q
|
||||||
|
@ -90,13 +90,13 @@ func newAutodrainingInboundQueue(device *Device) *autodrainingInboundQueue {
|
||||||
func (device *Device) flushInboundQueue(q *autodrainingInboundQueue) {
|
func (device *Device) flushInboundQueue(q *autodrainingInboundQueue) {
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case elems := <-q.c:
|
case elemsContainer := <-q.c:
|
||||||
for _, elem := range *elems {
|
elemsContainer.Lock()
|
||||||
elem.Lock()
|
for _, elem := range elemsContainer.elems {
|
||||||
device.PutMessageBuffer(elem.buffer)
|
device.PutMessageBuffer(elem.buffer)
|
||||||
device.PutInboundElement(elem)
|
device.PutInboundElement(elem)
|
||||||
}
|
}
|
||||||
device.PutInboundElementsSlice(elems)
|
device.PutInboundElementsContainer(elemsContainer)
|
||||||
default:
|
default:
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -104,7 +104,7 @@ func (device *Device) flushInboundQueue(q *autodrainingInboundQueue) {
|
||||||
}
|
}
|
||||||
|
|
||||||
type autodrainingOutboundQueue struct {
|
type autodrainingOutboundQueue struct {
|
||||||
c chan *[]*QueueOutboundElement
|
c chan *QueueOutboundElementsContainer
|
||||||
}
|
}
|
||||||
|
|
||||||
// newAutodrainingOutboundQueue returns a channel that will be drained when it gets GC'd.
|
// newAutodrainingOutboundQueue returns a channel that will be drained when it gets GC'd.
|
||||||
|
@ -114,7 +114,7 @@ type autodrainingOutboundQueue struct {
|
||||||
// All sends to the channel must be best-effort, because there may be no receivers.
|
// All sends to the channel must be best-effort, because there may be no receivers.
|
||||||
func newAutodrainingOutboundQueue(device *Device) *autodrainingOutboundQueue {
|
func newAutodrainingOutboundQueue(device *Device) *autodrainingOutboundQueue {
|
||||||
q := &autodrainingOutboundQueue{
|
q := &autodrainingOutboundQueue{
|
||||||
c: make(chan *[]*QueueOutboundElement, QueueOutboundSize),
|
c: make(chan *QueueOutboundElementsContainer, QueueOutboundSize),
|
||||||
}
|
}
|
||||||
runtime.SetFinalizer(q, device.flushOutboundQueue)
|
runtime.SetFinalizer(q, device.flushOutboundQueue)
|
||||||
return q
|
return q
|
||||||
|
@ -123,13 +123,13 @@ func newAutodrainingOutboundQueue(device *Device) *autodrainingOutboundQueue {
|
||||||
func (device *Device) flushOutboundQueue(q *autodrainingOutboundQueue) {
|
func (device *Device) flushOutboundQueue(q *autodrainingOutboundQueue) {
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case elems := <-q.c:
|
case elemsContainer := <-q.c:
|
||||||
for _, elem := range *elems {
|
elemsContainer.Lock()
|
||||||
elem.Lock()
|
for _, elem := range elemsContainer.elems {
|
||||||
device.PutMessageBuffer(elem.buffer)
|
device.PutMessageBuffer(elem.buffer)
|
||||||
device.PutOutboundElement(elem)
|
device.PutOutboundElement(elem)
|
||||||
}
|
}
|
||||||
device.PutOutboundElementsSlice(elems)
|
device.PutOutboundElementsContainer(elemsContainer)
|
||||||
default:
|
default:
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
290
device/device.go
290
device/device.go
|
@ -11,11 +11,11 @@ import (
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/amnezia-vpn/amnezia-wg/conn"
|
"github.com/amnezia-vpn/amneziawg-go/conn"
|
||||||
"github.com/amnezia-vpn/amnezia-wg/ipc"
|
"github.com/amnezia-vpn/amneziawg-go/ipc"
|
||||||
"github.com/amnezia-vpn/amnezia-wg/ratelimiter"
|
"github.com/amnezia-vpn/amneziawg-go/ratelimiter"
|
||||||
"github.com/amnezia-vpn/amnezia-wg/rwcancel"
|
"github.com/amnezia-vpn/amneziawg-go/rwcancel"
|
||||||
"github.com/amnezia-vpn/amnezia-wg/tun"
|
"github.com/amnezia-vpn/amneziawg-go/tun"
|
||||||
"github.com/tevino/abool/v2"
|
"github.com/tevino/abool/v2"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -70,11 +70,11 @@ type Device struct {
|
||||||
cookieChecker CookieChecker
|
cookieChecker CookieChecker
|
||||||
|
|
||||||
pool struct {
|
pool struct {
|
||||||
outboundElementsSlice *WaitPool
|
inboundElementsContainer *WaitPool
|
||||||
inboundElementsSlice *WaitPool
|
outboundElementsContainer *WaitPool
|
||||||
messageBuffers *WaitPool
|
messageBuffers *WaitPool
|
||||||
inboundElements *WaitPool
|
inboundElements *WaitPool
|
||||||
outboundElements *WaitPool
|
outboundElements *WaitPool
|
||||||
}
|
}
|
||||||
|
|
||||||
queue struct {
|
queue struct {
|
||||||
|
@ -95,9 +95,11 @@ type Device struct {
|
||||||
isASecOn abool.AtomicBool
|
isASecOn abool.AtomicBool
|
||||||
aSecMux sync.RWMutex
|
aSecMux sync.RWMutex
|
||||||
aSecCfg aSecCfgType
|
aSecCfg aSecCfgType
|
||||||
|
junkCreator junkCreator
|
||||||
}
|
}
|
||||||
|
|
||||||
type aSecCfgType struct {
|
type aSecCfgType struct {
|
||||||
|
isSet bool
|
||||||
junkPacketCount int
|
junkPacketCount int
|
||||||
junkPacketMinSize int
|
junkPacketMinSize int
|
||||||
junkPacketMaxSize int
|
junkPacketMaxSize int
|
||||||
|
@ -387,10 +389,10 @@ func (device *Device) RemoveAllPeers() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (device *Device) Close() {
|
func (device *Device) Close() {
|
||||||
device.ipcMutex.Lock()
|
|
||||||
defer device.ipcMutex.Unlock()
|
|
||||||
device.state.Lock()
|
device.state.Lock()
|
||||||
defer device.state.Unlock()
|
defer device.state.Unlock()
|
||||||
|
device.ipcMutex.Lock()
|
||||||
|
defer device.ipcMutex.Unlock()
|
||||||
if device.isClosed() {
|
if device.isClosed() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -414,6 +416,8 @@ func (device *Device) Close() {
|
||||||
|
|
||||||
device.rate.limiter.Close()
|
device.rate.limiter.Close()
|
||||||
|
|
||||||
|
device.resetProtocol()
|
||||||
|
|
||||||
device.log.Verbosef("Device closed")
|
device.log.Verbosef("Device closed")
|
||||||
close(device.closed)
|
close(device.closed)
|
||||||
}
|
}
|
||||||
|
@ -480,11 +484,7 @@ func (device *Device) BindSetMark(mark uint32) error {
|
||||||
// clear cached source addresses
|
// clear cached source addresses
|
||||||
device.peers.RLock()
|
device.peers.RLock()
|
||||||
for _, peer := range device.peers.keyMap {
|
for _, peer := range device.peers.keyMap {
|
||||||
peer.Lock()
|
peer.markEndpointSrcForClearing()
|
||||||
defer peer.Unlock()
|
|
||||||
if peer.endpoint != nil {
|
|
||||||
peer.endpoint.ClearSrc()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
device.peers.RUnlock()
|
device.peers.RUnlock()
|
||||||
|
|
||||||
|
@ -534,18 +534,14 @@ func (device *Device) BindUpdate() error {
|
||||||
// clear cached source addresses
|
// clear cached source addresses
|
||||||
device.peers.RLock()
|
device.peers.RLock()
|
||||||
for _, peer := range device.peers.keyMap {
|
for _, peer := range device.peers.keyMap {
|
||||||
peer.Lock()
|
peer.markEndpointSrcForClearing()
|
||||||
defer peer.Unlock()
|
|
||||||
if peer.endpoint != nil {
|
|
||||||
peer.endpoint.ClearSrc()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
device.peers.RUnlock()
|
device.peers.RUnlock()
|
||||||
|
|
||||||
// start receiving routines
|
// start receiving routines
|
||||||
device.net.stopping.Add(len(recvFns))
|
device.net.stopping.Add(len(recvFns))
|
||||||
device.queue.decryption.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption
|
device.queue.decryption.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption
|
||||||
device.queue.handshake.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.handshake
|
device.queue.handshake.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.handshake
|
||||||
batchSize := netc.bind.BatchSize()
|
batchSize := netc.bind.BatchSize()
|
||||||
for _, fn := range recvFns {
|
for _, fn := range recvFns {
|
||||||
go device.RoutineReceiveIncoming(batchSize, fn)
|
go device.RoutineReceiveIncoming(batchSize, fn)
|
||||||
|
@ -565,7 +561,20 @@ func (device *Device) isAdvancedSecurityOn() bool {
|
||||||
return device.isASecOn.IsSet()
|
return device.isASecOn.IsSet()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (device *Device) resetProtocol() {
|
||||||
|
// restore default message type values
|
||||||
|
MessageInitiationType = 1
|
||||||
|
MessageResponseType = 2
|
||||||
|
MessageCookieReplyType = 3
|
||||||
|
MessageTransportType = 4
|
||||||
|
}
|
||||||
|
|
||||||
func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
|
func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
|
||||||
|
|
||||||
|
if !tempASecCfg.isSet {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
isASecOn := false
|
isASecOn := false
|
||||||
device.aSecMux.Lock()
|
device.aSecMux.Lock()
|
||||||
if tempASecCfg.junkPacketCount < 0 {
|
if tempASecCfg.junkPacketCount < 0 {
|
||||||
|
@ -573,113 +582,115 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
|
||||||
ipc.IpcErrorInvalid,
|
ipc.IpcErrorInvalid,
|
||||||
"JunkPacketCount should be non negative",
|
"JunkPacketCount should be non negative",
|
||||||
)
|
)
|
||||||
} else if tempASecCfg.junkPacketCount > 0 {
|
}
|
||||||
device.log.Verbosef("UAPI: Updating junk_packet_count")
|
device.aSecCfg.junkPacketCount = tempASecCfg.junkPacketCount
|
||||||
device.aSecCfg.junkPacketCount = tempASecCfg.junkPacketCount
|
if tempASecCfg.junkPacketCount != 0 {
|
||||||
isASecOn = true
|
isASecOn = true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
device.aSecCfg.junkPacketMinSize = tempASecCfg.junkPacketMinSize
|
||||||
if tempASecCfg.junkPacketMinSize != 0 {
|
if tempASecCfg.junkPacketMinSize != 0 {
|
||||||
device.log.Verbosef("UAPI: Updating junk_packet_min_size")
|
|
||||||
device.aSecCfg.junkPacketMinSize = tempASecCfg.junkPacketMinSize
|
|
||||||
isASecOn = true
|
isASecOn = true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if device.aSecCfg.junkPacketCount > 0 &&
|
||||||
|
tempASecCfg.junkPacketMaxSize == tempASecCfg.junkPacketMinSize {
|
||||||
|
|
||||||
|
tempASecCfg.junkPacketMaxSize++ // to make rand gen work
|
||||||
|
}
|
||||||
|
|
||||||
|
if tempASecCfg.junkPacketMaxSize >= MaxSegmentSize {
|
||||||
|
device.aSecCfg.junkPacketMinSize = 0
|
||||||
|
device.aSecCfg.junkPacketMaxSize = 1
|
||||||
|
if err != nil {
|
||||||
|
err = ipcErrorf(
|
||||||
|
ipc.IpcErrorInvalid,
|
||||||
|
"JunkPacketMaxSize: %d; should be smaller than maxSegmentSize: %d; %w",
|
||||||
|
tempASecCfg.junkPacketMaxSize,
|
||||||
|
MaxSegmentSize,
|
||||||
|
err,
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
err = ipcErrorf(
|
||||||
|
ipc.IpcErrorInvalid,
|
||||||
|
"JunkPacketMaxSize: %d; should be smaller than maxSegmentSize: %d",
|
||||||
|
tempASecCfg.junkPacketMaxSize,
|
||||||
|
MaxSegmentSize,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
} else if tempASecCfg.junkPacketMaxSize < tempASecCfg.junkPacketMinSize {
|
||||||
|
if err != nil {
|
||||||
|
err = ipcErrorf(
|
||||||
|
ipc.IpcErrorInvalid,
|
||||||
|
"maxSize: %d; should be greater than minSize: %d; %w",
|
||||||
|
tempASecCfg.junkPacketMaxSize,
|
||||||
|
tempASecCfg.junkPacketMinSize,
|
||||||
|
err,
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
err = ipcErrorf(
|
||||||
|
ipc.IpcErrorInvalid,
|
||||||
|
"maxSize: %d; should be greater than minSize: %d",
|
||||||
|
tempASecCfg.junkPacketMaxSize,
|
||||||
|
tempASecCfg.junkPacketMinSize,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
device.aSecCfg.junkPacketMaxSize = tempASecCfg.junkPacketMaxSize
|
||||||
|
}
|
||||||
|
|
||||||
if tempASecCfg.junkPacketMaxSize != 0 {
|
if tempASecCfg.junkPacketMaxSize != 0 {
|
||||||
if tempASecCfg.junkPacketMaxSize == tempASecCfg.junkPacketMinSize {
|
isASecOn = true
|
||||||
tempASecCfg.junkPacketMaxSize++ // to make rand gen work
|
|
||||||
}
|
|
||||||
if tempASecCfg.junkPacketMaxSize >= MaxSegmentSize{
|
|
||||||
device.aSecCfg.junkPacketMinSize = 0
|
|
||||||
device.aSecCfg.junkPacketMaxSize = 1
|
|
||||||
if err != nil {
|
|
||||||
err = ipcErrorf(
|
|
||||||
ipc.IpcErrorInvalid,
|
|
||||||
"JunkPacketMaxSize: %d; should be smaller than maxSegmentSize: %d; %w",
|
|
||||||
tempASecCfg.junkPacketMaxSize,
|
|
||||||
MaxSegmentSize,
|
|
||||||
err,
|
|
||||||
)
|
|
||||||
} else {
|
|
||||||
err = ipcErrorf(
|
|
||||||
ipc.IpcErrorInvalid,
|
|
||||||
"JunkPacketMaxSize: %d; should be smaller than maxSegmentSize: %d",
|
|
||||||
tempASecCfg.junkPacketMaxSize,
|
|
||||||
MaxSegmentSize,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
} else if tempASecCfg.junkPacketMaxSize < tempASecCfg.junkPacketMinSize {
|
|
||||||
if err != nil {
|
|
||||||
err = ipcErrorf(
|
|
||||||
ipc.IpcErrorInvalid,
|
|
||||||
"maxSize: %d; should be greater than minSize: %d; %w",
|
|
||||||
tempASecCfg.junkPacketMaxSize,
|
|
||||||
tempASecCfg.junkPacketMinSize,
|
|
||||||
err,
|
|
||||||
)
|
|
||||||
} else {
|
|
||||||
err = ipcErrorf(
|
|
||||||
ipc.IpcErrorInvalid,
|
|
||||||
"maxSize: %d; should be greater than minSize: %d",
|
|
||||||
tempASecCfg.junkPacketMaxSize,
|
|
||||||
tempASecCfg.junkPacketMinSize,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
device.log.Verbosef("UAPI: Updating junk_packet_max_size")
|
|
||||||
device.aSecCfg.junkPacketMaxSize = tempASecCfg.junkPacketMaxSize
|
|
||||||
isASecOn = true
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if MessageInitiationSize+tempASecCfg.initPacketJunkSize >= MaxSegmentSize {
|
||||||
|
if err != nil {
|
||||||
|
err = ipcErrorf(
|
||||||
|
ipc.IpcErrorInvalid,
|
||||||
|
`init header size(148) + junkSize:%d; should be smaller than maxSegmentSize: %d; %w`,
|
||||||
|
tempASecCfg.initPacketJunkSize,
|
||||||
|
MaxSegmentSize,
|
||||||
|
err,
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
err = ipcErrorf(
|
||||||
|
ipc.IpcErrorInvalid,
|
||||||
|
`init header size(148) + junkSize:%d; should be smaller than maxSegmentSize: %d`,
|
||||||
|
tempASecCfg.initPacketJunkSize,
|
||||||
|
MaxSegmentSize,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
device.aSecCfg.initPacketJunkSize = tempASecCfg.initPacketJunkSize
|
||||||
|
}
|
||||||
|
|
||||||
if tempASecCfg.initPacketJunkSize != 0 {
|
if tempASecCfg.initPacketJunkSize != 0 {
|
||||||
if MessageInitiationSize+tempASecCfg.initPacketJunkSize >= MaxSegmentSize {
|
isASecOn = true
|
||||||
if err != nil {
|
|
||||||
err = ipcErrorf(
|
|
||||||
ipc.IpcErrorInvalid,
|
|
||||||
`init header size(148) + junkSize:%d;
|
|
||||||
should be smaller than maxSegmentSize: %d; %w`,
|
|
||||||
tempASecCfg.initPacketJunkSize,
|
|
||||||
MaxSegmentSize,
|
|
||||||
err,
|
|
||||||
)
|
|
||||||
} else {
|
|
||||||
err = ipcErrorf(
|
|
||||||
ipc.IpcErrorInvalid,
|
|
||||||
`init header size(148) + junkSize:%d;
|
|
||||||
should be smaller than maxSegmentSize: %d`,
|
|
||||||
tempASecCfg.initPacketJunkSize,
|
|
||||||
MaxSegmentSize,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
device.log.Verbosef("UAPI: Updating init_packet_junk_size")
|
|
||||||
device.aSecCfg.initPacketJunkSize = tempASecCfg.initPacketJunkSize
|
|
||||||
isASecOn = true
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
if tempASecCfg.responsePacketJunkSize != 0 {
|
|
||||||
if MessageResponseSize+tempASecCfg.responsePacketJunkSize >= MaxSegmentSize {
|
if MessageResponseSize+tempASecCfg.responsePacketJunkSize >= MaxSegmentSize {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = ipcErrorf(
|
err = ipcErrorf(
|
||||||
ipc.IpcErrorInvalid,
|
ipc.IpcErrorInvalid,
|
||||||
`response header size(92) + junkSize:%d;
|
`response header size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d; %w`,
|
||||||
should be smaller than maxSegmentSize: %d; %w`,
|
tempASecCfg.responsePacketJunkSize,
|
||||||
tempASecCfg.responsePacketJunkSize,
|
MaxSegmentSize,
|
||||||
MaxSegmentSize,
|
err,
|
||||||
err,
|
)
|
||||||
)
|
|
||||||
} else {
|
|
||||||
err = ipcErrorf(
|
|
||||||
ipc.IpcErrorInvalid,
|
|
||||||
`response header size(92) + junkSize:%d;
|
|
||||||
should be smaller than maxSegmentSize: %d`,
|
|
||||||
tempASecCfg.responsePacketJunkSize,
|
|
||||||
MaxSegmentSize,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
device.log.Verbosef("UAPI: Updating response_packet_junk_size")
|
err = ipcErrorf(
|
||||||
device.aSecCfg.responsePacketJunkSize = tempASecCfg.responsePacketJunkSize
|
ipc.IpcErrorInvalid,
|
||||||
isASecOn = true
|
`response header size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d`,
|
||||||
|
tempASecCfg.responsePacketJunkSize,
|
||||||
|
MaxSegmentSize,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
device.aSecCfg.responsePacketJunkSize = tempASecCfg.responsePacketJunkSize
|
||||||
|
}
|
||||||
|
|
||||||
|
if tempASecCfg.responsePacketJunkSize != 0 {
|
||||||
|
isASecOn = true
|
||||||
}
|
}
|
||||||
|
|
||||||
if tempASecCfg.initPacketMagicHeader > 4 {
|
if tempASecCfg.initPacketMagicHeader > 4 {
|
||||||
|
@ -691,6 +702,7 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
|
||||||
device.log.Verbosef("UAPI: Using default init type")
|
device.log.Verbosef("UAPI: Using default init type")
|
||||||
MessageInitiationType = 1
|
MessageInitiationType = 1
|
||||||
}
|
}
|
||||||
|
|
||||||
if tempASecCfg.responsePacketMagicHeader > 4 {
|
if tempASecCfg.responsePacketMagicHeader > 4 {
|
||||||
isASecOn = true
|
isASecOn = true
|
||||||
device.log.Verbosef("UAPI: Updating response_packet_magic_header")
|
device.log.Verbosef("UAPI: Updating response_packet_magic_header")
|
||||||
|
@ -700,6 +712,7 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
|
||||||
device.log.Verbosef("UAPI: Using default response type")
|
device.log.Verbosef("UAPI: Using default response type")
|
||||||
MessageResponseType = 2
|
MessageResponseType = 2
|
||||||
}
|
}
|
||||||
|
|
||||||
if tempASecCfg.underloadPacketMagicHeader > 4 {
|
if tempASecCfg.underloadPacketMagicHeader > 4 {
|
||||||
isASecOn = true
|
isASecOn = true
|
||||||
device.log.Verbosef("UAPI: Updating underload_packet_magic_header")
|
device.log.Verbosef("UAPI: Updating underload_packet_magic_header")
|
||||||
|
@ -709,6 +722,7 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
|
||||||
device.log.Verbosef("UAPI: Using default underload type")
|
device.log.Verbosef("UAPI: Using default underload type")
|
||||||
MessageCookieReplyType = 3
|
MessageCookieReplyType = 3
|
||||||
}
|
}
|
||||||
|
|
||||||
if tempASecCfg.transportPacketMagicHeader > 4 {
|
if tempASecCfg.transportPacketMagicHeader > 4 {
|
||||||
isASecOn = true
|
isASecOn = true
|
||||||
device.log.Verbosef("UAPI: Updating transport_packet_magic_header")
|
device.log.Verbosef("UAPI: Updating transport_packet_magic_header")
|
||||||
|
@ -716,10 +730,39 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
|
||||||
MessageTransportType = device.aSecCfg.transportPacketMagicHeader
|
MessageTransportType = device.aSecCfg.transportPacketMagicHeader
|
||||||
} else {
|
} else {
|
||||||
device.log.Verbosef("UAPI: Using default transport type")
|
device.log.Verbosef("UAPI: Using default transport type")
|
||||||
|
|
||||||
MessageTransportType = 4
|
MessageTransportType = 4
|
||||||
}
|
}
|
||||||
|
|
||||||
|
isSameMap := map[uint32]bool{}
|
||||||
|
isSameMap[MessageInitiationType] = true
|
||||||
|
isSameMap[MessageResponseType] = true
|
||||||
|
isSameMap[MessageCookieReplyType] = true
|
||||||
|
isSameMap[MessageTransportType] = true
|
||||||
|
|
||||||
|
// size will be different if same values
|
||||||
|
if len(isSameMap) != 4 {
|
||||||
|
if err != nil {
|
||||||
|
err = ipcErrorf(
|
||||||
|
ipc.IpcErrorInvalid,
|
||||||
|
`magic headers should differ; got: init:%d; recv:%d; unde:%d; tran:%d; %w`,
|
||||||
|
MessageInitiationType,
|
||||||
|
MessageResponseType,
|
||||||
|
MessageCookieReplyType,
|
||||||
|
MessageTransportType,
|
||||||
|
err,
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
err = ipcErrorf(
|
||||||
|
ipc.IpcErrorInvalid,
|
||||||
|
`magic headers should differ; got: init:%d; recv:%d; unde:%d; tran:%d`,
|
||||||
|
MessageInitiationType,
|
||||||
|
MessageResponseType,
|
||||||
|
MessageCookieReplyType,
|
||||||
|
MessageTransportType,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
newInitSize := MessageInitiationSize + device.aSecCfg.initPacketJunkSize
|
newInitSize := MessageInitiationSize + device.aSecCfg.initPacketJunkSize
|
||||||
newResponseSize := MessageResponseSize + device.aSecCfg.responsePacketJunkSize
|
newResponseSize := MessageResponseSize + device.aSecCfg.responsePacketJunkSize
|
||||||
|
|
||||||
|
@ -742,8 +785,8 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
packetSizeToMsgType = map[int]uint32{
|
packetSizeToMsgType = map[int]uint32{
|
||||||
newInitSize: MessageInitiationType,
|
newInitSize: MessageInitiationType,
|
||||||
newResponseSize: MessageResponseType,
|
newResponseSize: MessageResponseType,
|
||||||
MessageCookieReplySize: MessageCookieReplyType,
|
MessageCookieReplySize: MessageCookieReplyType,
|
||||||
MessageTransportSize: MessageTransportType,
|
MessageTransportSize: MessageTransportType,
|
||||||
}
|
}
|
||||||
|
@ -757,7 +800,8 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
device.isASecOn.SetTo(isASecOn)
|
device.isASecOn.SetTo(isASecOn)
|
||||||
|
device.junkCreator, err = NewJunkCreator(device)
|
||||||
device.aSecMux.Unlock()
|
device.aSecMux.Unlock()
|
||||||
|
|
||||||
return nil
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,10 +20,10 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/amnezia-vpn/amnezia-wg/conn"
|
"github.com/amnezia-vpn/amneziawg-go/conn"
|
||||||
"github.com/amnezia-vpn/amnezia-wg/conn/bindtest"
|
"github.com/amnezia-vpn/amneziawg-go/conn/bindtest"
|
||||||
"github.com/amnezia-vpn/amnezia-wg/tun"
|
"github.com/amnezia-vpn/amneziawg-go/tun"
|
||||||
"github.com/amnezia-vpn/amnezia-wg/tun/tuntest"
|
"github.com/amnezia-vpn/amneziawg-go/tun/tuntest"
|
||||||
)
|
)
|
||||||
|
|
||||||
// uapiCfg returns a string that contains cfg formatted use with IpcSet.
|
// uapiCfg returns a string that contains cfg formatted use with IpcSet.
|
||||||
|
@ -109,7 +109,7 @@ func genASecurityConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) {
|
||||||
"replace_peers", "true",
|
"replace_peers", "true",
|
||||||
"jc", "5",
|
"jc", "5",
|
||||||
"jmin", "500",
|
"jmin", "500",
|
||||||
"jmax", "501",
|
"jmax", "1000",
|
||||||
"s1", "30",
|
"s1", "30",
|
||||||
"s2", "40",
|
"s2", "40",
|
||||||
"h1", "123456",
|
"h1", "123456",
|
||||||
|
@ -131,7 +131,7 @@ func genASecurityConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) {
|
||||||
"replace_peers", "true",
|
"replace_peers", "true",
|
||||||
"jc", "5",
|
"jc", "5",
|
||||||
"jmin", "500",
|
"jmin", "500",
|
||||||
"jmax", "501",
|
"jmax", "1000",
|
||||||
"s1", "30",
|
"s1", "30",
|
||||||
"s2", "40",
|
"s2", "40",
|
||||||
"h1", "123456",
|
"h1", "123456",
|
||||||
|
@ -237,7 +237,7 @@ func genTestPair(
|
||||||
if _, ok := tb.(*testing.B); ok && !testing.Verbose() {
|
if _, ok := tb.(*testing.B); ok && !testing.Verbose() {
|
||||||
level = LogLevelError
|
level = LogLevelError
|
||||||
}
|
}
|
||||||
p.dev = NewDevice(p.tun.TUN(),binds[i],NewLogger(level, fmt.Sprintf("dev%d: ", i)))
|
p.dev = NewDevice(p.tun.TUN(), binds[i], NewLogger(level, fmt.Sprintf("dev%d: ", i)))
|
||||||
if err := p.dev.IpcSet(cfg[i]); err != nil {
|
if err := p.dev.IpcSet(cfg[i]); err != nil {
|
||||||
tb.Errorf("failed to configure device %d: %v", i, err)
|
tb.Errorf("failed to configure device %d: %v", i, err)
|
||||||
p.dev.Close()
|
p.dev.Close()
|
||||||
|
@ -274,7 +274,7 @@ func TestTwoDevicePing(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTwoDevicePingASecurity(t *testing.T) {
|
func TestASecurityTwoDevicePing(t *testing.T) {
|
||||||
goroutineLeakCheck(t)
|
goroutineLeakCheck(t)
|
||||||
pair := genTestPair(t, true, true)
|
pair := genTestPair(t, true, true)
|
||||||
t.Run("ping 1.0.0.1", func(t *testing.T) {
|
t.Run("ping 1.0.0.1", func(t *testing.T) {
|
||||||
|
@ -294,7 +294,7 @@ func TestUpDown(t *testing.T) {
|
||||||
pair := genTestPair(t, false, false)
|
pair := genTestPair(t, false, false)
|
||||||
for i := range pair {
|
for i := range pair {
|
||||||
for k := range pair[i].dev.peers.keyMap {
|
for k := range pair[i].dev.peers.keyMap {
|
||||||
pair[i].dev.IpcSet(fmt.Sprintf("public_key=%s\npersistent_keepalive_interval=1\n",hex.EncodeToString(k[:])))
|
pair[i].dev.IpcSet(fmt.Sprintf("public_key=%s\npersistent_keepalive_interval=1\n", hex.EncodeToString(k[:])))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
|
@ -513,7 +513,7 @@ func (b *fakeBindSized) Open(
|
||||||
|
|
||||||
func (b *fakeBindSized) Close() error { return nil }
|
func (b *fakeBindSized) Close() error { return nil }
|
||||||
|
|
||||||
func (b *fakeBindSized) SetMark(mark uint32) error {return nil }
|
func (b *fakeBindSized) SetMark(mark uint32) error { return nil }
|
||||||
|
|
||||||
func (b *fakeBindSized) Send(bufs [][]byte, ep conn.Endpoint) error { return nil }
|
func (b *fakeBindSized) Send(bufs [][]byte, ep conn.Endpoint) error { return nil }
|
||||||
|
|
||||||
|
@ -527,7 +527,9 @@ type fakeTUNDeviceSized struct {
|
||||||
|
|
||||||
func (t *fakeTUNDeviceSized) File() *os.File { return nil }
|
func (t *fakeTUNDeviceSized) File() *os.File { return nil }
|
||||||
|
|
||||||
func (t *fakeTUNDeviceSized) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) { return 0, nil }
|
func (t *fakeTUNDeviceSized) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (t *fakeTUNDeviceSized) Write(bufs [][]byte, offset int) (int, error) { return 0, nil }
|
func (t *fakeTUNDeviceSized) Write(bufs [][]byte, offset int) (int, error) { return 0, nil }
|
||||||
|
|
||||||
|
|
69
device/junk_creator.go
Normal file
69
device/junk_creator.go
Normal file
|
@ -0,0 +1,69 @@
|
||||||
|
package device
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
crand "crypto/rand"
|
||||||
|
"fmt"
|
||||||
|
v2 "math/rand/v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
type junkCreator struct {
|
||||||
|
device *Device
|
||||||
|
cha8Rand *v2.ChaCha8
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewJunkCreator(d *Device) (junkCreator, error) {
|
||||||
|
buf := make([]byte, 32)
|
||||||
|
_, err := crand.Read(buf)
|
||||||
|
if err != nil {
|
||||||
|
return junkCreator{}, err
|
||||||
|
}
|
||||||
|
return junkCreator{device: d, cha8Rand: v2.NewChaCha8([32]byte(buf))}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should be called with aSecMux RLocked
|
||||||
|
func (jc *junkCreator) createJunkPackets() ([][]byte, error) {
|
||||||
|
if jc.device.aSecCfg.junkPacketCount == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
junks := make([][]byte, 0, jc.device.aSecCfg.junkPacketCount)
|
||||||
|
for i := 0; i < jc.device.aSecCfg.junkPacketCount; i++ {
|
||||||
|
packetSize := jc.randomPacketSize()
|
||||||
|
junk, err := jc.randomJunkWithSize(packetSize)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("Failed to create junk packet: %v", err)
|
||||||
|
}
|
||||||
|
junks = append(junks, junk)
|
||||||
|
}
|
||||||
|
return junks, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should be called with aSecMux RLocked
|
||||||
|
func (jc *junkCreator) randomPacketSize() int {
|
||||||
|
return int(
|
||||||
|
jc.cha8Rand.Uint64()%uint64(
|
||||||
|
jc.device.aSecCfg.junkPacketMaxSize-jc.device.aSecCfg.junkPacketMinSize,
|
||||||
|
),
|
||||||
|
) + jc.device.aSecCfg.junkPacketMinSize
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should be called with aSecMux RLocked
|
||||||
|
func (jc *junkCreator) appendJunk(writer *bytes.Buffer, size int) error {
|
||||||
|
headerJunk, err := jc.randomJunkWithSize(size)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create header junk: %v", err)
|
||||||
|
}
|
||||||
|
_, err = writer.Write(headerJunk)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to write header junk: %v", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should be called with aSecMux RLocked
|
||||||
|
func (jc *junkCreator) randomJunkWithSize(size int) ([]byte, error) {
|
||||||
|
junk := make([]byte, size)
|
||||||
|
_, err := jc.cha8Rand.Read(junk)
|
||||||
|
return junk, err
|
||||||
|
}
|
124
device/junk_creator_test.go
Normal file
124
device/junk_creator_test.go
Normal file
|
@ -0,0 +1,124 @@
|
||||||
|
package device
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/amnezia-vpn/amneziawg-go/conn/bindtest"
|
||||||
|
"github.com/amnezia-vpn/amneziawg-go/tun/tuntest"
|
||||||
|
)
|
||||||
|
|
||||||
|
func setUpJunkCreator(t *testing.T) (junkCreator, error) {
|
||||||
|
cfg, _ := genASecurityConfigs(t)
|
||||||
|
tun := tuntest.NewChannelTUN()
|
||||||
|
binds := bindtest.NewChannelBinds()
|
||||||
|
level := LogLevelVerbose
|
||||||
|
dev := NewDevice(
|
||||||
|
tun.TUN(),
|
||||||
|
binds[0],
|
||||||
|
NewLogger(level, ""),
|
||||||
|
)
|
||||||
|
|
||||||
|
if err := dev.IpcSet(cfg[0]); err != nil {
|
||||||
|
t.Errorf("failed to configure device %v", err)
|
||||||
|
dev.Close()
|
||||||
|
return junkCreator{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
jc, err := NewJunkCreator(dev)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to create junk creator %v", err)
|
||||||
|
dev.Close()
|
||||||
|
return junkCreator{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return jc, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_junkCreator_createJunkPackets(t *testing.T) {
|
||||||
|
jc, err := setUpJunkCreator(t)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.Run("", func(t *testing.T) {
|
||||||
|
got, err := jc.createJunkPackets()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf(
|
||||||
|
"junkCreator.createJunkPackets() = %v; failed",
|
||||||
|
err,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
seen := make(map[string]bool)
|
||||||
|
for _, junk := range got {
|
||||||
|
key := string(junk)
|
||||||
|
if seen[key] {
|
||||||
|
t.Errorf(
|
||||||
|
"junkCreator.createJunkPackets() = %v, duplicate key: %v",
|
||||||
|
got,
|
||||||
|
junk,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
seen[key] = true
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_junkCreator_randomJunkWithSize(t *testing.T) {
|
||||||
|
t.Run("", func(t *testing.T) {
|
||||||
|
jc, err := setUpJunkCreator(t)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
r1, _ := jc.randomJunkWithSize(10)
|
||||||
|
r2, _ := jc.randomJunkWithSize(10)
|
||||||
|
fmt.Printf("%v\n%v\n", r1, r2)
|
||||||
|
if bytes.Equal(r1, r2) {
|
||||||
|
t.Errorf("same junks %v", err)
|
||||||
|
jc.device.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_junkCreator_randomPacketSize(t *testing.T) {
|
||||||
|
jc, err := setUpJunkCreator(t)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for range [30]struct{}{} {
|
||||||
|
t.Run("", func(t *testing.T) {
|
||||||
|
if got := jc.randomPacketSize(); jc.device.aSecCfg.junkPacketMinSize > got ||
|
||||||
|
got > jc.device.aSecCfg.junkPacketMaxSize {
|
||||||
|
t.Errorf(
|
||||||
|
"junkCreator.randomPacketSize() = %v, not between range [%v,%v]",
|
||||||
|
got,
|
||||||
|
jc.device.aSecCfg.junkPacketMinSize,
|
||||||
|
jc.device.aSecCfg.junkPacketMaxSize,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_junkCreator_appendJunk(t *testing.T) {
|
||||||
|
jc, err := setUpJunkCreator(t)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.Run("", func(t *testing.T) {
|
||||||
|
s := "apple"
|
||||||
|
buffer := bytes.NewBuffer([]byte(s))
|
||||||
|
err := jc.appendJunk(buffer, 30)
|
||||||
|
if err != nil &&
|
||||||
|
buffer.Len() != len(s)+30 {
|
||||||
|
t.Errorf("appendWithJunk() size don't match")
|
||||||
|
}
|
||||||
|
read := make([]byte, 50)
|
||||||
|
buffer.Read(read)
|
||||||
|
fmt.Println(string(read))
|
||||||
|
})
|
||||||
|
}
|
|
@ -11,7 +11,7 @@ import (
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/amnezia-vpn/amnezia-wg/replay"
|
"github.com/amnezia-vpn/amneziawg-go/replay"
|
||||||
)
|
)
|
||||||
|
|
||||||
/* Due to limitations in Go and /x/crypto there is currently
|
/* Due to limitations in Go and /x/crypto there is currently
|
||||||
|
|
|
@ -11,9 +11,9 @@ func (device *Device) DisableSomeRoamingForBrokenMobileSemantics() {
|
||||||
device.net.brokenRoaming = true
|
device.net.brokenRoaming = true
|
||||||
device.peers.RLock()
|
device.peers.RLock()
|
||||||
for _, peer := range device.peers.keyMap {
|
for _, peer := range device.peers.keyMap {
|
||||||
peer.Lock()
|
peer.endpoint.Lock()
|
||||||
peer.disableRoaming = peer.endpoint != nil
|
peer.endpoint.disableRoaming = peer.endpoint.val != nil
|
||||||
peer.Unlock()
|
peer.endpoint.Unlock()
|
||||||
}
|
}
|
||||||
device.peers.RUnlock()
|
device.peers.RUnlock()
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,7 +15,7 @@ import (
|
||||||
"golang.org/x/crypto/chacha20poly1305"
|
"golang.org/x/crypto/chacha20poly1305"
|
||||||
"golang.org/x/crypto/poly1305"
|
"golang.org/x/crypto/poly1305"
|
||||||
|
|
||||||
"github.com/amnezia-vpn/amnezia-wg/tai64n"
|
"github.com/amnezia-vpn/amneziawg-go/tai64n"
|
||||||
)
|
)
|
||||||
|
|
||||||
type handshakeState int
|
type handshakeState int
|
||||||
|
|
|
@ -10,8 +10,8 @@ import (
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/amnezia-vpn/amnezia-wg/conn"
|
"github.com/amnezia-vpn/amneziawg-go/conn"
|
||||||
"github.com/amnezia-vpn/amnezia-wg/tun/tuntest"
|
"github.com/amnezia-vpn/amneziawg-go/tun/tuntest"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestCurveWrappers(t *testing.T) {
|
func TestCurveWrappers(t *testing.T) {
|
||||||
|
|
|
@ -12,22 +12,25 @@ import (
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/amnezia-vpn/amnezia-wg/conn"
|
"github.com/amnezia-vpn/amneziawg-go/conn"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Peer struct {
|
type Peer struct {
|
||||||
isRunning atomic.Bool
|
isRunning atomic.Bool
|
||||||
sync.RWMutex // Mostly protects endpoint, but is generally taken whenever we modify peer
|
|
||||||
keypairs Keypairs
|
keypairs Keypairs
|
||||||
handshake Handshake
|
handshake Handshake
|
||||||
device *Device
|
device *Device
|
||||||
endpoint conn.Endpoint
|
|
||||||
stopping sync.WaitGroup // routines pending stop
|
stopping sync.WaitGroup // routines pending stop
|
||||||
txBytes atomic.Uint64 // bytes send to peer (endpoint)
|
txBytes atomic.Uint64 // bytes send to peer (endpoint)
|
||||||
rxBytes atomic.Uint64 // bytes received from peer
|
rxBytes atomic.Uint64 // bytes received from peer
|
||||||
lastHandshakeNano atomic.Int64 // nano seconds since epoch
|
lastHandshakeNano atomic.Int64 // nano seconds since epoch
|
||||||
|
|
||||||
disableRoaming bool
|
endpoint struct {
|
||||||
|
sync.Mutex
|
||||||
|
val conn.Endpoint
|
||||||
|
clearSrcOnTx bool // signal to val.ClearSrc() prior to next packet transmission
|
||||||
|
disableRoaming bool
|
||||||
|
}
|
||||||
|
|
||||||
timers struct {
|
timers struct {
|
||||||
retransmitHandshake *Timer
|
retransmitHandshake *Timer
|
||||||
|
@ -45,9 +48,9 @@ type Peer struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
queue struct {
|
queue struct {
|
||||||
staged chan *[]*QueueOutboundElement // staged packets before a handshake is available
|
staged chan *QueueOutboundElementsContainer // staged packets before a handshake is available
|
||||||
outbound *autodrainingOutboundQueue // sequential ordering of udp transmission
|
outbound *autodrainingOutboundQueue // sequential ordering of udp transmission
|
||||||
inbound *autodrainingInboundQueue // sequential ordering of tun writing
|
inbound *autodrainingInboundQueue // sequential ordering of tun writing
|
||||||
}
|
}
|
||||||
|
|
||||||
cookieGenerator CookieGenerator
|
cookieGenerator CookieGenerator
|
||||||
|
@ -74,14 +77,12 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
|
||||||
|
|
||||||
// create peer
|
// create peer
|
||||||
peer := new(Peer)
|
peer := new(Peer)
|
||||||
peer.Lock()
|
|
||||||
defer peer.Unlock()
|
|
||||||
|
|
||||||
peer.cookieGenerator.Init(pk)
|
peer.cookieGenerator.Init(pk)
|
||||||
peer.device = device
|
peer.device = device
|
||||||
peer.queue.outbound = newAutodrainingOutboundQueue(device)
|
peer.queue.outbound = newAutodrainingOutboundQueue(device)
|
||||||
peer.queue.inbound = newAutodrainingInboundQueue(device)
|
peer.queue.inbound = newAutodrainingInboundQueue(device)
|
||||||
peer.queue.staged = make(chan *[]*QueueOutboundElement, QueueStagedSize)
|
peer.queue.staged = make(chan *QueueOutboundElementsContainer, QueueStagedSize)
|
||||||
|
|
||||||
// map public key
|
// map public key
|
||||||
_, ok := device.peers.keyMap[pk]
|
_, ok := device.peers.keyMap[pk]
|
||||||
|
@ -97,7 +98,11 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
|
||||||
handshake.mutex.Unlock()
|
handshake.mutex.Unlock()
|
||||||
|
|
||||||
// reset endpoint
|
// reset endpoint
|
||||||
peer.endpoint = nil
|
peer.endpoint.Lock()
|
||||||
|
peer.endpoint.val = nil
|
||||||
|
peer.endpoint.disableRoaming = false
|
||||||
|
peer.endpoint.clearSrcOnTx = false
|
||||||
|
peer.endpoint.Unlock()
|
||||||
|
|
||||||
// init timers
|
// init timers
|
||||||
peer.timersInit()
|
peer.timersInit()
|
||||||
|
@ -116,14 +121,19 @@ func (peer *Peer) SendBuffers(buffers [][]byte) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
peer.RLock()
|
peer.endpoint.Lock()
|
||||||
defer peer.RUnlock()
|
endpoint := peer.endpoint.val
|
||||||
|
if endpoint == nil {
|
||||||
if peer.endpoint == nil {
|
peer.endpoint.Unlock()
|
||||||
return errors.New("no known endpoint for peer")
|
return errors.New("no known endpoint for peer")
|
||||||
}
|
}
|
||||||
|
if peer.endpoint.clearSrcOnTx {
|
||||||
|
endpoint.ClearSrc()
|
||||||
|
peer.endpoint.clearSrcOnTx = false
|
||||||
|
}
|
||||||
|
peer.endpoint.Unlock()
|
||||||
|
|
||||||
err := peer.device.net.bind.Send(buffers, peer.endpoint)
|
err := peer.device.net.bind.Send(buffers, endpoint)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
var totalLen uint64
|
var totalLen uint64
|
||||||
for _, b := range buffers {
|
for _, b := range buffers {
|
||||||
|
@ -267,10 +277,20 @@ func (peer *Peer) Stop() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (peer *Peer) SetEndpointFromPacket(endpoint conn.Endpoint) {
|
func (peer *Peer) SetEndpointFromPacket(endpoint conn.Endpoint) {
|
||||||
if peer.disableRoaming {
|
peer.endpoint.Lock()
|
||||||
|
defer peer.endpoint.Unlock()
|
||||||
|
if peer.endpoint.disableRoaming {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
peer.Lock()
|
peer.endpoint.clearSrcOnTx = false
|
||||||
peer.endpoint = endpoint
|
peer.endpoint.val = endpoint
|
||||||
peer.Unlock()
|
}
|
||||||
|
|
||||||
|
func (peer *Peer) markEndpointSrcForClearing() {
|
||||||
|
peer.endpoint.Lock()
|
||||||
|
defer peer.endpoint.Unlock()
|
||||||
|
if peer.endpoint.val == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
peer.endpoint.clearSrcOnTx = true
|
||||||
}
|
}
|
||||||
|
|
|
@ -46,13 +46,13 @@ func (p *WaitPool) Put(x any) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (device *Device) PopulatePools() {
|
func (device *Device) PopulatePools() {
|
||||||
device.pool.outboundElementsSlice = NewWaitPool(PreallocatedBuffersPerPool, func() any {
|
device.pool.inboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any {
|
||||||
s := make([]*QueueOutboundElement, 0, device.BatchSize())
|
|
||||||
return &s
|
|
||||||
})
|
|
||||||
device.pool.inboundElementsSlice = NewWaitPool(PreallocatedBuffersPerPool, func() any {
|
|
||||||
s := make([]*QueueInboundElement, 0, device.BatchSize())
|
s := make([]*QueueInboundElement, 0, device.BatchSize())
|
||||||
return &s
|
return &QueueInboundElementsContainer{elems: s}
|
||||||
|
})
|
||||||
|
device.pool.outboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any {
|
||||||
|
s := make([]*QueueOutboundElement, 0, device.BatchSize())
|
||||||
|
return &QueueOutboundElementsContainer{elems: s}
|
||||||
})
|
})
|
||||||
device.pool.messageBuffers = NewWaitPool(PreallocatedBuffersPerPool, func() any {
|
device.pool.messageBuffers = NewWaitPool(PreallocatedBuffersPerPool, func() any {
|
||||||
return new([MaxMessageSize]byte)
|
return new([MaxMessageSize]byte)
|
||||||
|
@ -65,28 +65,32 @@ func (device *Device) PopulatePools() {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (device *Device) GetOutboundElementsSlice() *[]*QueueOutboundElement {
|
func (device *Device) GetInboundElementsContainer() *QueueInboundElementsContainer {
|
||||||
return device.pool.outboundElementsSlice.Get().(*[]*QueueOutboundElement)
|
c := device.pool.inboundElementsContainer.Get().(*QueueInboundElementsContainer)
|
||||||
|
c.Mutex = sync.Mutex{}
|
||||||
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
func (device *Device) PutOutboundElementsSlice(s *[]*QueueOutboundElement) {
|
func (device *Device) PutInboundElementsContainer(c *QueueInboundElementsContainer) {
|
||||||
for i := range *s {
|
for i := range c.elems {
|
||||||
(*s)[i] = nil
|
c.elems[i] = nil
|
||||||
}
|
}
|
||||||
*s = (*s)[:0]
|
c.elems = c.elems[:0]
|
||||||
device.pool.outboundElementsSlice.Put(s)
|
device.pool.inboundElementsContainer.Put(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (device *Device) GetInboundElementsSlice() *[]*QueueInboundElement {
|
func (device *Device) GetOutboundElementsContainer() *QueueOutboundElementsContainer {
|
||||||
return device.pool.inboundElementsSlice.Get().(*[]*QueueInboundElement)
|
c := device.pool.outboundElementsContainer.Get().(*QueueOutboundElementsContainer)
|
||||||
|
c.Mutex = sync.Mutex{}
|
||||||
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
func (device *Device) PutInboundElementsSlice(s *[]*QueueInboundElement) {
|
func (device *Device) PutOutboundElementsContainer(c *QueueOutboundElementsContainer) {
|
||||||
for i := range *s {
|
for i := range c.elems {
|
||||||
(*s)[i] = nil
|
c.elems[i] = nil
|
||||||
}
|
}
|
||||||
*s = (*s)[:0]
|
c.elems = c.elems[:0]
|
||||||
device.pool.inboundElementsSlice.Put(s)
|
device.pool.outboundElementsContainer.Put(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte {
|
func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte {
|
||||||
|
|
|
@ -5,7 +5,7 @@
|
||||||
|
|
||||||
package device
|
package device
|
||||||
|
|
||||||
import "github.com/amnezia-vpn/amnezia-wg/conn"
|
import "github.com/amnezia-vpn/amneziawg-go/conn"
|
||||||
|
|
||||||
/* Reduce memory consumption for Android */
|
/* Reduce memory consumption for Android */
|
||||||
|
|
||||||
|
@ -14,6 +14,6 @@ const (
|
||||||
QueueOutboundSize = 1024
|
QueueOutboundSize = 1024
|
||||||
QueueInboundSize = 1024
|
QueueInboundSize = 1024
|
||||||
QueueHandshakeSize = 1024
|
QueueHandshakeSize = 1024
|
||||||
MaxSegmentSize = 2200
|
MaxSegmentSize = (1 << 16) - 1 // largest possible UDP datagram
|
||||||
PreallocatedBuffersPerPool = 4096
|
PreallocatedBuffersPerPool = 4096
|
||||||
)
|
)
|
||||||
|
|
|
@ -7,7 +7,7 @@
|
||||||
|
|
||||||
package device
|
package device
|
||||||
|
|
||||||
import "github.com/amnezia-vpn/amnezia-wg/conn"
|
import "github.com/amnezia-vpn/amneziawg-go/conn"
|
||||||
|
|
||||||
const (
|
const (
|
||||||
QueueStagedSize = conn.IdealBatchSize
|
QueueStagedSize = conn.IdealBatchSize
|
||||||
|
|
|
@ -13,7 +13,7 @@ import (
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/amnezia-vpn/amnezia-wg/conn"
|
"github.com/amnezia-vpn/amneziawg-go/conn"
|
||||||
"golang.org/x/crypto/chacha20poly1305"
|
"golang.org/x/crypto/chacha20poly1305"
|
||||||
"golang.org/x/net/ipv4"
|
"golang.org/x/net/ipv4"
|
||||||
"golang.org/x/net/ipv6"
|
"golang.org/x/net/ipv6"
|
||||||
|
@ -27,7 +27,6 @@ type QueueHandshakeElement struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type QueueInboundElement struct {
|
type QueueInboundElement struct {
|
||||||
sync.Mutex
|
|
||||||
buffer *[MaxMessageSize]byte
|
buffer *[MaxMessageSize]byte
|
||||||
packet []byte
|
packet []byte
|
||||||
counter uint64
|
counter uint64
|
||||||
|
@ -35,6 +34,11 @@ type QueueInboundElement struct {
|
||||||
endpoint conn.Endpoint
|
endpoint conn.Endpoint
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type QueueInboundElementsContainer struct {
|
||||||
|
sync.Mutex
|
||||||
|
elems []*QueueInboundElement
|
||||||
|
}
|
||||||
|
|
||||||
// clearPointers clears elem fields that contain pointers.
|
// clearPointers clears elem fields that contain pointers.
|
||||||
// This makes the garbage collector's life easier and
|
// This makes the garbage collector's life easier and
|
||||||
// avoids accidentally keeping other objects around unnecessarily.
|
// avoids accidentally keeping other objects around unnecessarily.
|
||||||
|
@ -90,7 +94,7 @@ func (device *Device) RoutineReceiveIncoming(
|
||||||
count int
|
count int
|
||||||
endpoints = make([]conn.Endpoint, maxBatchSize)
|
endpoints = make([]conn.Endpoint, maxBatchSize)
|
||||||
deathSpiral int
|
deathSpiral int
|
||||||
elemsByPeer = make(map[*Peer]*[]*QueueInboundElement, maxBatchSize)
|
elemsByPeer = make(map[*Peer]*QueueInboundElementsContainer, maxBatchSize)
|
||||||
)
|
)
|
||||||
|
|
||||||
for i := range bufsArrs {
|
for i := range bufsArrs {
|
||||||
|
@ -141,10 +145,11 @@ func (device *Device) RoutineReceiveIncoming(
|
||||||
junkSize := msgTypeToJunkSize[assumedMsgType]
|
junkSize := msgTypeToJunkSize[assumedMsgType]
|
||||||
// transport size can align with other header types;
|
// transport size can align with other header types;
|
||||||
// making sure we have the right msgType
|
// making sure we have the right msgType
|
||||||
msgType = binary.LittleEndian.Uint32(packet[junkSize:4])
|
msgType = binary.LittleEndian.Uint32(packet[junkSize : junkSize+4])
|
||||||
if msgType == assumedMsgType {
|
if msgType == assumedMsgType {
|
||||||
packet = packet[junkSize:]
|
packet = packet[junkSize:]
|
||||||
} else {
|
} else {
|
||||||
|
device.log.Verbosef("Transport packet lined up with another msg type")
|
||||||
msgType = binary.LittleEndian.Uint32(packet[:4])
|
msgType = binary.LittleEndian.Uint32(packet[:4])
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
@ -194,15 +199,14 @@ func (device *Device) RoutineReceiveIncoming(
|
||||||
elem.keypair = keypair
|
elem.keypair = keypair
|
||||||
elem.endpoint = endpoints[i]
|
elem.endpoint = endpoints[i]
|
||||||
elem.counter = 0
|
elem.counter = 0
|
||||||
elem.Mutex = sync.Mutex{}
|
|
||||||
elem.Lock()
|
|
||||||
|
|
||||||
elemsForPeer, ok := elemsByPeer[peer]
|
elemsForPeer, ok := elemsByPeer[peer]
|
||||||
if !ok {
|
if !ok {
|
||||||
elemsForPeer = device.GetInboundElementsSlice()
|
elemsForPeer = device.GetInboundElementsContainer()
|
||||||
|
elemsForPeer.Lock()
|
||||||
elemsByPeer[peer] = elemsForPeer
|
elemsByPeer[peer] = elemsForPeer
|
||||||
}
|
}
|
||||||
*elemsForPeer = append(*elemsForPeer, elem)
|
elemsForPeer.elems = append(elemsForPeer.elems, elem)
|
||||||
bufsArrs[i] = device.GetMessageBuffer()
|
bufsArrs[i] = device.GetMessageBuffer()
|
||||||
bufs[i] = bufsArrs[i][:]
|
bufs[i] = bufsArrs[i][:]
|
||||||
continue
|
continue
|
||||||
|
@ -242,18 +246,16 @@ func (device *Device) RoutineReceiveIncoming(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
device.aSecMux.RUnlock()
|
device.aSecMux.RUnlock()
|
||||||
for peer, elems := range elemsByPeer {
|
for peer, elemsContainer := range elemsByPeer {
|
||||||
if peer.isRunning.Load() {
|
if peer.isRunning.Load() {
|
||||||
peer.queue.inbound.c <- elems
|
peer.queue.inbound.c <- elemsContainer
|
||||||
for _, elem := range *elems {
|
device.queue.decryption.c <- elemsContainer
|
||||||
device.queue.decryption.c <- elem
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
for _, elem := range *elems {
|
for _, elem := range elemsContainer.elems {
|
||||||
device.PutMessageBuffer(elem.buffer)
|
device.PutMessageBuffer(elem.buffer)
|
||||||
device.PutInboundElement(elem)
|
device.PutInboundElement(elem)
|
||||||
}
|
}
|
||||||
device.PutInboundElementsSlice(elems)
|
device.PutInboundElementsContainer(elemsContainer)
|
||||||
}
|
}
|
||||||
delete(elemsByPeer, peer)
|
delete(elemsByPeer, peer)
|
||||||
}
|
}
|
||||||
|
@ -266,26 +268,28 @@ func (device *Device) RoutineDecryption(id int) {
|
||||||
defer device.log.Verbosef("Routine: decryption worker %d - stopped", id)
|
defer device.log.Verbosef("Routine: decryption worker %d - stopped", id)
|
||||||
device.log.Verbosef("Routine: decryption worker %d - started", id)
|
device.log.Verbosef("Routine: decryption worker %d - started", id)
|
||||||
|
|
||||||
for elem := range device.queue.decryption.c {
|
for elemsContainer := range device.queue.decryption.c {
|
||||||
// split message into fields
|
for _, elem := range elemsContainer.elems {
|
||||||
counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent]
|
// split message into fields
|
||||||
content := elem.packet[MessageTransportOffsetContent:]
|
counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent]
|
||||||
|
content := elem.packet[MessageTransportOffsetContent:]
|
||||||
|
|
||||||
// decrypt and release to consumer
|
// decrypt and release to consumer
|
||||||
var err error
|
var err error
|
||||||
elem.counter = binary.LittleEndian.Uint64(counter)
|
elem.counter = binary.LittleEndian.Uint64(counter)
|
||||||
// copy counter to nonce
|
// copy counter to nonce
|
||||||
binary.LittleEndian.PutUint64(nonce[0x4:0xc], elem.counter)
|
binary.LittleEndian.PutUint64(nonce[0x4:0xc], elem.counter)
|
||||||
elem.packet, err = elem.keypair.receive.Open(
|
elem.packet, err = elem.keypair.receive.Open(
|
||||||
content[:0],
|
content[:0],
|
||||||
nonce[:],
|
nonce[:],
|
||||||
content,
|
content,
|
||||||
nil,
|
nil,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
elem.packet = nil
|
elem.packet = nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
elem.Unlock()
|
elemsContainer.Unlock()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -467,12 +471,15 @@ func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) {
|
||||||
|
|
||||||
bufs := make([][]byte, 0, maxBatchSize)
|
bufs := make([][]byte, 0, maxBatchSize)
|
||||||
|
|
||||||
for elems := range peer.queue.inbound.c {
|
for elemsContainer := range peer.queue.inbound.c {
|
||||||
if elems == nil {
|
if elemsContainer == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
for _, elem := range *elems {
|
elemsContainer.Lock()
|
||||||
elem.Lock()
|
validTailPacket := -1
|
||||||
|
dataPacketReceived := false
|
||||||
|
rxBytesLen := uint64(0)
|
||||||
|
for i, elem := range elemsContainer.elems {
|
||||||
if elem.packet == nil {
|
if elem.packet == nil {
|
||||||
// decryption failed
|
// decryption failed
|
||||||
continue
|
continue
|
||||||
|
@ -482,21 +489,19 @@ func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
peer.SetEndpointFromPacket(elem.endpoint)
|
validTailPacket = i
|
||||||
if peer.ReceivedWithKeypair(elem.keypair) {
|
if peer.ReceivedWithKeypair(elem.keypair) {
|
||||||
|
peer.SetEndpointFromPacket(elem.endpoint)
|
||||||
peer.timersHandshakeComplete()
|
peer.timersHandshakeComplete()
|
||||||
peer.SendStagedPackets()
|
peer.SendStagedPackets()
|
||||||
}
|
}
|
||||||
peer.keepKeyFreshReceiving()
|
rxBytesLen += uint64(len(elem.packet) + MinMessageSize)
|
||||||
peer.timersAnyAuthenticatedPacketTraversal()
|
|
||||||
peer.timersAnyAuthenticatedPacketReceived()
|
|
||||||
peer.rxBytes.Add(uint64(len(elem.packet) + MinMessageSize))
|
|
||||||
|
|
||||||
if len(elem.packet) == 0 {
|
if len(elem.packet) == 0 {
|
||||||
device.log.Verbosef("%v - Receiving keepalive packet", peer)
|
device.log.Verbosef("%v - Receiving keepalive packet", peer)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
peer.timersDataReceived()
|
dataPacketReceived = true
|
||||||
|
|
||||||
switch elem.packet[0] >> 4 {
|
switch elem.packet[0] >> 4 {
|
||||||
case 4:
|
case 4:
|
||||||
|
@ -545,17 +550,28 @@ func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) {
|
||||||
elem.buffer[:MessageTransportOffsetContent+len(elem.packet)],
|
elem.buffer[:MessageTransportOffsetContent+len(elem.packet)],
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
peer.rxBytes.Add(rxBytesLen)
|
||||||
|
if validTailPacket >= 0 {
|
||||||
|
peer.SetEndpointFromPacket(elemsContainer.elems[validTailPacket].endpoint)
|
||||||
|
peer.keepKeyFreshReceiving()
|
||||||
|
peer.timersAnyAuthenticatedPacketTraversal()
|
||||||
|
peer.timersAnyAuthenticatedPacketReceived()
|
||||||
|
}
|
||||||
|
if dataPacketReceived {
|
||||||
|
peer.timersDataReceived()
|
||||||
|
}
|
||||||
if len(bufs) > 0 {
|
if len(bufs) > 0 {
|
||||||
_, err := device.tun.device.Write(bufs, MessageTransportOffsetContent)
|
_, err := device.tun.device.Write(bufs, MessageTransportOffsetContent)
|
||||||
if err != nil && !device.isClosed() {
|
if err != nil && !device.isClosed() {
|
||||||
device.log.Errorf("Failed to write packets to TUN device: %v", err)
|
device.log.Errorf("Failed to write packets to TUN device: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for _, elem := range *elems {
|
for _, elem := range elemsContainer.elems {
|
||||||
device.PutMessageBuffer(elem.buffer)
|
device.PutMessageBuffer(elem.buffer)
|
||||||
device.PutInboundElement(elem)
|
device.PutInboundElement(elem)
|
||||||
}
|
}
|
||||||
bufs = bufs[:0]
|
bufs = bufs[:0]
|
||||||
device.PutInboundElementsSlice(elems)
|
device.PutInboundElementsContainer(elemsContainer)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
193
device/send.go
193
device/send.go
|
@ -9,13 +9,13 @@ import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
"math/rand"
|
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/amnezia-vpn/amnezia-wg/tun"
|
"github.com/amnezia-vpn/amneziawg-go/conn"
|
||||||
|
"github.com/amnezia-vpn/amneziawg-go/tun"
|
||||||
"golang.org/x/crypto/chacha20poly1305"
|
"golang.org/x/crypto/chacha20poly1305"
|
||||||
"golang.org/x/net/ipv4"
|
"golang.org/x/net/ipv4"
|
||||||
"golang.org/x/net/ipv6"
|
"golang.org/x/net/ipv6"
|
||||||
|
@ -46,7 +46,6 @@ import (
|
||||||
*/
|
*/
|
||||||
|
|
||||||
type QueueOutboundElement struct {
|
type QueueOutboundElement struct {
|
||||||
sync.Mutex
|
|
||||||
buffer *[MaxMessageSize]byte // slice holding the packet data
|
buffer *[MaxMessageSize]byte // slice holding the packet data
|
||||||
packet []byte // slice of "buffer" (always!)
|
packet []byte // slice of "buffer" (always!)
|
||||||
nonce uint64 // nonce for encryption
|
nonce uint64 // nonce for encryption
|
||||||
|
@ -54,10 +53,14 @@ type QueueOutboundElement struct {
|
||||||
peer *Peer // related peer
|
peer *Peer // related peer
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type QueueOutboundElementsContainer struct {
|
||||||
|
sync.Mutex
|
||||||
|
elems []*QueueOutboundElement
|
||||||
|
}
|
||||||
|
|
||||||
func (device *Device) NewOutboundElement() *QueueOutboundElement {
|
func (device *Device) NewOutboundElement() *QueueOutboundElement {
|
||||||
elem := device.GetOutboundElement()
|
elem := device.GetOutboundElement()
|
||||||
elem.buffer = device.GetMessageBuffer()
|
elem.buffer = device.GetMessageBuffer()
|
||||||
elem.Mutex = sync.Mutex{}
|
|
||||||
elem.nonce = 0
|
elem.nonce = 0
|
||||||
// keypair and peer were cleared (if necessary) by clearPointers.
|
// keypair and peer were cleared (if necessary) by clearPointers.
|
||||||
return elem
|
return elem
|
||||||
|
@ -79,15 +82,15 @@ func (elem *QueueOutboundElement) clearPointers() {
|
||||||
func (peer *Peer) SendKeepalive() {
|
func (peer *Peer) SendKeepalive() {
|
||||||
if len(peer.queue.staged) == 0 && peer.isRunning.Load() {
|
if len(peer.queue.staged) == 0 && peer.isRunning.Load() {
|
||||||
elem := peer.device.NewOutboundElement()
|
elem := peer.device.NewOutboundElement()
|
||||||
elems := peer.device.GetOutboundElementsSlice()
|
elemsContainer := peer.device.GetOutboundElementsContainer()
|
||||||
*elems = append(*elems, elem)
|
elemsContainer.elems = append(elemsContainer.elems, elem)
|
||||||
select {
|
select {
|
||||||
case peer.queue.staged <- elems:
|
case peer.queue.staged <- elemsContainer:
|
||||||
peer.device.log.Verbosef("%v - Sending keepalive packet", peer)
|
peer.device.log.Verbosef("%v - Sending keepalive packet", peer)
|
||||||
default:
|
default:
|
||||||
peer.device.PutMessageBuffer(elem.buffer)
|
peer.device.PutMessageBuffer(elem.buffer)
|
||||||
peer.device.PutOutboundElement(elem)
|
peer.device.PutOutboundElement(elem)
|
||||||
peer.device.PutOutboundElementsSlice(elems)
|
peer.device.PutOutboundElementsContainer(elemsContainer)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
peer.SendStagedPackets()
|
peer.SendStagedPackets()
|
||||||
|
@ -125,26 +128,38 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
|
||||||
var junkedHeader []byte
|
var junkedHeader []byte
|
||||||
if peer.device.isAdvancedSecurityOn() {
|
if peer.device.isAdvancedSecurityOn() {
|
||||||
peer.device.aSecMux.RLock()
|
peer.device.aSecMux.RLock()
|
||||||
junks, err := peer.createJunkPackets()
|
junks, err := peer.device.junkCreator.createJunkPackets()
|
||||||
|
peer.device.aSecMux.RUnlock()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
peer.device.aSecMux.RUnlock()
|
|
||||||
peer.device.log.Errorf("%v - %v", peer, err)
|
peer.device.log.Errorf("%v - %v", peer, err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
sendBuffer = append(sendBuffer, junks...)
|
|
||||||
|
if len(junks) > 0 {
|
||||||
|
err = peer.SendBuffers(junks)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
peer.device.log.Errorf("%v - Failed to send junk packets: %v", peer, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
peer.device.aSecMux.RLock()
|
||||||
if peer.device.aSecCfg.initPacketJunkSize != 0 {
|
if peer.device.aSecCfg.initPacketJunkSize != 0 {
|
||||||
buf := make([]byte, 0, peer.device.aSecCfg.initPacketJunkSize)
|
buf := make([]byte, 0, peer.device.aSecCfg.initPacketJunkSize)
|
||||||
writer := bytes.NewBuffer(buf[:0])
|
writer := bytes.NewBuffer(buf[:0])
|
||||||
err = appendJunk(writer, peer.device.aSecCfg.initPacketJunkSize)
|
err = peer.device.junkCreator.appendJunk(writer, peer.device.aSecCfg.initPacketJunkSize)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
peer.device.aSecMux.RUnlock()
|
|
||||||
peer.device.log.Errorf("%v - %v", peer, err)
|
peer.device.log.Errorf("%v - %v", peer, err)
|
||||||
|
peer.device.aSecMux.RUnlock()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
junkedHeader = writer.Bytes()
|
junkedHeader = writer.Bytes()
|
||||||
}
|
}
|
||||||
peer.device.aSecMux.RUnlock()
|
peer.device.aSecMux.RUnlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
var buf [MessageInitiationSize]byte
|
var buf [MessageInitiationSize]byte
|
||||||
writer := bytes.NewBuffer(buf[:0])
|
writer := bytes.NewBuffer(buf[:0])
|
||||||
binary.Write(writer, binary.LittleEndian, msg)
|
binary.Write(writer, binary.LittleEndian, msg)
|
||||||
|
@ -184,7 +199,7 @@ func (peer *Peer) SendHandshakeResponse() error {
|
||||||
if peer.device.aSecCfg.responsePacketJunkSize != 0 {
|
if peer.device.aSecCfg.responsePacketJunkSize != 0 {
|
||||||
buf := make([]byte, 0, peer.device.aSecCfg.responsePacketJunkSize)
|
buf := make([]byte, 0, peer.device.aSecCfg.responsePacketJunkSize)
|
||||||
writer := bytes.NewBuffer(buf[:0])
|
writer := bytes.NewBuffer(buf[:0])
|
||||||
err = appendJunk(writer, peer.device.aSecCfg.responsePacketJunkSize)
|
err = peer.device.junkCreator.appendJunk(writer, peer.device.aSecCfg.responsePacketJunkSize)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
peer.device.aSecMux.RUnlock()
|
peer.device.aSecMux.RUnlock()
|
||||||
peer.device.log.Errorf("%v - %v", peer, err)
|
peer.device.log.Errorf("%v - %v", peer, err)
|
||||||
|
@ -269,7 +284,7 @@ func (device *Device) RoutineReadFromTUN() {
|
||||||
readErr error
|
readErr error
|
||||||
elems = make([]*QueueOutboundElement, batchSize)
|
elems = make([]*QueueOutboundElement, batchSize)
|
||||||
bufs = make([][]byte, batchSize)
|
bufs = make([][]byte, batchSize)
|
||||||
elemsByPeer = make(map[*Peer]*[]*QueueOutboundElement, batchSize)
|
elemsByPeer = make(map[*Peer]*QueueOutboundElementsContainer, batchSize)
|
||||||
count = 0
|
count = 0
|
||||||
sizes = make([]int, batchSize)
|
sizes = make([]int, batchSize)
|
||||||
offset = MessageTransportHeaderSize
|
offset = MessageTransportHeaderSize
|
||||||
|
@ -326,10 +341,10 @@ func (device *Device) RoutineReadFromTUN() {
|
||||||
}
|
}
|
||||||
elemsForPeer, ok := elemsByPeer[peer]
|
elemsForPeer, ok := elemsByPeer[peer]
|
||||||
if !ok {
|
if !ok {
|
||||||
elemsForPeer = device.GetOutboundElementsSlice()
|
elemsForPeer = device.GetOutboundElementsContainer()
|
||||||
elemsByPeer[peer] = elemsForPeer
|
elemsByPeer[peer] = elemsForPeer
|
||||||
}
|
}
|
||||||
*elemsForPeer = append(*elemsForPeer, elem)
|
elemsForPeer.elems = append(elemsForPeer.elems, elem)
|
||||||
elems[i] = device.NewOutboundElement()
|
elems[i] = device.NewOutboundElement()
|
||||||
bufs[i] = elems[i].buffer[:]
|
bufs[i] = elems[i].buffer[:]
|
||||||
}
|
}
|
||||||
|
@ -339,11 +354,11 @@ func (device *Device) RoutineReadFromTUN() {
|
||||||
peer.StagePackets(elemsForPeer)
|
peer.StagePackets(elemsForPeer)
|
||||||
peer.SendStagedPackets()
|
peer.SendStagedPackets()
|
||||||
} else {
|
} else {
|
||||||
for _, elem := range *elemsForPeer {
|
for _, elem := range elemsForPeer.elems {
|
||||||
device.PutMessageBuffer(elem.buffer)
|
device.PutMessageBuffer(elem.buffer)
|
||||||
device.PutOutboundElement(elem)
|
device.PutOutboundElement(elem)
|
||||||
}
|
}
|
||||||
device.PutOutboundElementsSlice(elemsForPeer)
|
device.PutOutboundElementsContainer(elemsForPeer)
|
||||||
}
|
}
|
||||||
delete(elemsByPeer, peer)
|
delete(elemsByPeer, peer)
|
||||||
}
|
}
|
||||||
|
@ -367,7 +382,7 @@ func (device *Device) RoutineReadFromTUN() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (peer *Peer) StagePackets(elems *[]*QueueOutboundElement) {
|
func (peer *Peer) StagePackets(elems *QueueOutboundElementsContainer) {
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case peer.queue.staged <- elems:
|
case peer.queue.staged <- elems:
|
||||||
|
@ -376,11 +391,11 @@ func (peer *Peer) StagePackets(elems *[]*QueueOutboundElement) {
|
||||||
}
|
}
|
||||||
select {
|
select {
|
||||||
case tooOld := <-peer.queue.staged:
|
case tooOld := <-peer.queue.staged:
|
||||||
for _, elem := range *tooOld {
|
for _, elem := range tooOld.elems {
|
||||||
peer.device.PutMessageBuffer(elem.buffer)
|
peer.device.PutMessageBuffer(elem.buffer)
|
||||||
peer.device.PutOutboundElement(elem)
|
peer.device.PutOutboundElement(elem)
|
||||||
}
|
}
|
||||||
peer.device.PutOutboundElementsSlice(tooOld)
|
peer.device.PutOutboundElementsContainer(tooOld)
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -399,54 +414,52 @@ top:
|
||||||
}
|
}
|
||||||
|
|
||||||
for {
|
for {
|
||||||
var elemsOOO *[]*QueueOutboundElement
|
var elemsContainerOOO *QueueOutboundElementsContainer
|
||||||
select {
|
select {
|
||||||
case elems := <-peer.queue.staged:
|
case elemsContainer := <-peer.queue.staged:
|
||||||
i := 0
|
i := 0
|
||||||
for _, elem := range *elems {
|
for _, elem := range elemsContainer.elems {
|
||||||
elem.peer = peer
|
elem.peer = peer
|
||||||
elem.nonce = keypair.sendNonce.Add(1) - 1
|
elem.nonce = keypair.sendNonce.Add(1) - 1
|
||||||
if elem.nonce >= RejectAfterMessages {
|
if elem.nonce >= RejectAfterMessages {
|
||||||
keypair.sendNonce.Store(RejectAfterMessages)
|
keypair.sendNonce.Store(RejectAfterMessages)
|
||||||
if elemsOOO == nil {
|
if elemsContainerOOO == nil {
|
||||||
elemsOOO = peer.device.GetOutboundElementsSlice()
|
elemsContainerOOO = peer.device.GetOutboundElementsContainer()
|
||||||
}
|
}
|
||||||
*elemsOOO = append(*elemsOOO, elem)
|
elemsContainerOOO.elems = append(elemsContainerOOO.elems, elem)
|
||||||
continue
|
continue
|
||||||
} else {
|
} else {
|
||||||
(*elems)[i] = elem
|
elemsContainer.elems[i] = elem
|
||||||
i++
|
i++
|
||||||
}
|
}
|
||||||
|
|
||||||
elem.keypair = keypair
|
elem.keypair = keypair
|
||||||
elem.Lock()
|
|
||||||
}
|
}
|
||||||
*elems = (*elems)[:i]
|
elemsContainer.Lock()
|
||||||
|
elemsContainer.elems = elemsContainer.elems[:i]
|
||||||
|
|
||||||
if elemsOOO != nil {
|
if elemsContainerOOO != nil {
|
||||||
peer.StagePackets(elemsOOO) // XXX: Out of order, but we can't front-load go chans
|
peer.StagePackets(elemsContainerOOO) // XXX: Out of order, but we can't front-load go chans
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(*elems) == 0 {
|
if len(elemsContainer.elems) == 0 {
|
||||||
peer.device.PutOutboundElementsSlice(elems)
|
peer.device.PutOutboundElementsContainer(elemsContainer)
|
||||||
goto top
|
goto top
|
||||||
}
|
}
|
||||||
|
|
||||||
// add to parallel and sequential queue
|
// add to parallel and sequential queue
|
||||||
if peer.isRunning.Load() {
|
if peer.isRunning.Load() {
|
||||||
peer.queue.outbound.c <- elems
|
peer.queue.outbound.c <- elemsContainer
|
||||||
for _, elem := range *elems {
|
peer.device.queue.encryption.c <- elemsContainer
|
||||||
peer.device.queue.encryption.c <- elem
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
for _, elem := range *elems {
|
for _, elem := range elemsContainer.elems {
|
||||||
peer.device.PutMessageBuffer(elem.buffer)
|
peer.device.PutMessageBuffer(elem.buffer)
|
||||||
peer.device.PutOutboundElement(elem)
|
peer.device.PutOutboundElement(elem)
|
||||||
}
|
}
|
||||||
peer.device.PutOutboundElementsSlice(elems)
|
peer.device.PutOutboundElementsContainer(elemsContainer)
|
||||||
}
|
}
|
||||||
|
|
||||||
if elemsOOO != nil {
|
if elemsContainerOOO != nil {
|
||||||
goto top
|
goto top
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
|
@ -455,40 +468,15 @@ top:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (peer *Peer) createJunkPackets() ([][]byte, error) {
|
|
||||||
if peer.device.aSecCfg.junkPacketCount == 0 {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
junks := make([][]byte, 0, peer.device.aSecCfg.junkPacketCount)
|
|
||||||
for i := 0; i < peer.device.aSecCfg.junkPacketCount; i++ {
|
|
||||||
packetSize := rand.Intn(
|
|
||||||
peer.device.aSecCfg.junkPacketMaxSize-peer.device.aSecCfg.junkPacketMinSize,
|
|
||||||
) + peer.device.aSecCfg.junkPacketMinSize
|
|
||||||
|
|
||||||
junk, err := randomJunkWithSize(packetSize)
|
|
||||||
if err != nil {
|
|
||||||
peer.device.log.Errorf(
|
|
||||||
"%v - Failed to create junk packet: %v",
|
|
||||||
peer,
|
|
||||||
err,
|
|
||||||
)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
junks = append(junks, junk)
|
|
||||||
}
|
|
||||||
return junks, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (peer *Peer) FlushStagedPackets() {
|
func (peer *Peer) FlushStagedPackets() {
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case elems := <-peer.queue.staged:
|
case elemsContainer := <-peer.queue.staged:
|
||||||
for _, elem := range *elems {
|
for _, elem := range elemsContainer.elems {
|
||||||
peer.device.PutMessageBuffer(elem.buffer)
|
peer.device.PutMessageBuffer(elem.buffer)
|
||||||
peer.device.PutOutboundElement(elem)
|
peer.device.PutOutboundElement(elem)
|
||||||
}
|
}
|
||||||
peer.device.PutOutboundElementsSlice(elems)
|
peer.device.PutOutboundElementsContainer(elemsContainer)
|
||||||
default:
|
default:
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -522,30 +510,34 @@ func (device *Device) RoutineEncryption(id int) {
|
||||||
defer device.log.Verbosef("Routine: encryption worker %d - stopped", id)
|
defer device.log.Verbosef("Routine: encryption worker %d - stopped", id)
|
||||||
device.log.Verbosef("Routine: encryption worker %d - started", id)
|
device.log.Verbosef("Routine: encryption worker %d - started", id)
|
||||||
|
|
||||||
for elem := range device.queue.encryption.c {
|
for elemsContainer := range device.queue.encryption.c {
|
||||||
// populate header fields
|
for _, elem := range elemsContainer.elems {
|
||||||
header := elem.buffer[:MessageTransportHeaderSize]
|
// populate header fields
|
||||||
|
header := elem.buffer[:MessageTransportHeaderSize]
|
||||||
|
|
||||||
fieldType := header[0:4]
|
fieldType := header[0:4]
|
||||||
fieldReceiver := header[4:8]
|
fieldReceiver := header[4:8]
|
||||||
fieldNonce := header[8:16]
|
fieldNonce := header[8:16]
|
||||||
|
|
||||||
binary.LittleEndian.PutUint32(fieldType, MessageTransportType)
|
binary.LittleEndian.PutUint32(fieldType, MessageTransportType)
|
||||||
binary.LittleEndian.PutUint32(fieldReceiver, elem.keypair.remoteIndex)
|
binary.LittleEndian.PutUint32(fieldReceiver, elem.keypair.remoteIndex)
|
||||||
binary.LittleEndian.PutUint64(fieldNonce, elem.nonce)
|
binary.LittleEndian.PutUint64(fieldNonce, elem.nonce)
|
||||||
|
|
||||||
// pad content to multiple of 16
|
// pad content to multiple of 16
|
||||||
paddingSize := calculatePaddingSize(
|
paddingSize := calculatePaddingSize(len(elem.packet), int(device.tun.mtu.Load()))
|
||||||
len(elem.packet),
|
elem.packet = append(elem.packet, paddingZeros[:paddingSize]...)
|
||||||
int(device.tun.mtu.Load()),
|
|
||||||
)
|
|
||||||
elem.packet = append(elem.packet, paddingZeros[:paddingSize]...)
|
|
||||||
|
|
||||||
// encrypt content and release to consumer
|
// encrypt content and release to consumer
|
||||||
|
|
||||||
binary.LittleEndian.PutUint64(nonce[4:], elem.nonce)
|
binary.LittleEndian.PutUint64(nonce[4:], elem.nonce)
|
||||||
elem.packet = elem.keypair.send.Seal(header, nonce[:], elem.packet, nil)
|
elem.packet = elem.keypair.send.Seal(
|
||||||
elem.Unlock()
|
header,
|
||||||
|
nonce[:],
|
||||||
|
elem.packet,
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
elemsContainer.Unlock()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -559,9 +551,9 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) {
|
||||||
|
|
||||||
bufs := make([][]byte, 0, maxBatchSize)
|
bufs := make([][]byte, 0, maxBatchSize)
|
||||||
|
|
||||||
for elems := range peer.queue.outbound.c {
|
for elemsContainer := range peer.queue.outbound.c {
|
||||||
bufs = bufs[:0]
|
bufs = bufs[:0]
|
||||||
if elems == nil {
|
if elemsContainer == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if !peer.isRunning.Load() {
|
if !peer.isRunning.Load() {
|
||||||
|
@ -571,16 +563,16 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) {
|
||||||
// The timers and SendBuffers code are resilient to a few stragglers.
|
// The timers and SendBuffers code are resilient to a few stragglers.
|
||||||
// TODO: rework peer shutdown order to ensure
|
// TODO: rework peer shutdown order to ensure
|
||||||
// that we never accidentally keep timers alive longer than necessary.
|
// that we never accidentally keep timers alive longer than necessary.
|
||||||
for _, elem := range *elems {
|
elemsContainer.Lock()
|
||||||
elem.Lock()
|
for _, elem := range elemsContainer.elems {
|
||||||
device.PutMessageBuffer(elem.buffer)
|
device.PutMessageBuffer(elem.buffer)
|
||||||
device.PutOutboundElement(elem)
|
device.PutOutboundElement(elem)
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
dataSent := false
|
dataSent := false
|
||||||
for _, elem := range *elems {
|
elemsContainer.Lock()
|
||||||
elem.Lock()
|
for _, elem := range elemsContainer.elems {
|
||||||
if len(elem.packet) != MessageKeepaliveSize {
|
if len(elem.packet) != MessageKeepaliveSize {
|
||||||
dataSent = true
|
dataSent = true
|
||||||
}
|
}
|
||||||
|
@ -594,11 +586,18 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) {
|
||||||
if dataSent {
|
if dataSent {
|
||||||
peer.timersDataSent()
|
peer.timersDataSent()
|
||||||
}
|
}
|
||||||
for _, elem := range *elems {
|
for _, elem := range elemsContainer.elems {
|
||||||
device.PutMessageBuffer(elem.buffer)
|
device.PutMessageBuffer(elem.buffer)
|
||||||
device.PutOutboundElement(elem)
|
device.PutOutboundElement(elem)
|
||||||
}
|
}
|
||||||
device.PutOutboundElementsSlice(elems)
|
device.PutOutboundElementsContainer(elemsContainer)
|
||||||
|
if err != nil {
|
||||||
|
var errGSO conn.ErrUDPGSODisabled
|
||||||
|
if errors.As(err, &errGSO) {
|
||||||
|
device.log.Verbosef(err.Error())
|
||||||
|
err = errGSO.RetryErr
|
||||||
|
}
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
device.log.Errorf("%v - Failed to send data packets: %v", peer, err)
|
device.log.Errorf("%v - Failed to send data packets: %v", peer, err)
|
||||||
continue
|
continue
|
||||||
|
|
|
@ -3,8 +3,8 @@
|
||||||
package device
|
package device
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/amnezia-vpn/amnezia-wg/conn"
|
"github.com/amnezia-vpn/amneziawg-go/conn"
|
||||||
"github.com/amnezia-vpn/amnezia-wg/rwcancel"
|
"github.com/amnezia-vpn/amneziawg-go/rwcancel"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) {
|
func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) {
|
||||||
|
|
|
@ -20,8 +20,8 @@ import (
|
||||||
|
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
|
|
||||||
"github.com/amnezia-vpn/amnezia-wg/conn"
|
"github.com/amnezia-vpn/amneziawg-go/conn"
|
||||||
"github.com/amnezia-vpn/amnezia-wg/rwcancel"
|
"github.com/amnezia-vpn/amneziawg-go/rwcancel"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) {
|
func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) {
|
||||||
|
@ -110,17 +110,17 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl
|
||||||
if !ok {
|
if !ok {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
pePtr.peer.Lock()
|
pePtr.peer.endpoint.Lock()
|
||||||
if &pePtr.peer.endpoint != pePtr.endpoint {
|
if &pePtr.peer.endpoint.val != pePtr.endpoint {
|
||||||
pePtr.peer.Unlock()
|
pePtr.peer.endpoint.Unlock()
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
if uint32(pePtr.peer.endpoint.(*conn.StdNetEndpoint).SrcIfidx()) == ifidx {
|
if uint32(pePtr.peer.endpoint.val.(*conn.StdNetEndpoint).SrcIfidx()) == ifidx {
|
||||||
pePtr.peer.Unlock()
|
pePtr.peer.endpoint.Unlock()
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
pePtr.peer.endpoint.(*conn.StdNetEndpoint).ClearSrc()
|
pePtr.peer.endpoint.clearSrcOnTx = true
|
||||||
pePtr.peer.Unlock()
|
pePtr.peer.endpoint.Unlock()
|
||||||
}
|
}
|
||||||
attr = attr[attrhdr.Len:]
|
attr = attr[attrhdr.Len:]
|
||||||
}
|
}
|
||||||
|
@ -134,18 +134,18 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl
|
||||||
device.peers.RLock()
|
device.peers.RLock()
|
||||||
i := uint32(1)
|
i := uint32(1)
|
||||||
for _, peer := range device.peers.keyMap {
|
for _, peer := range device.peers.keyMap {
|
||||||
peer.RLock()
|
peer.endpoint.Lock()
|
||||||
if peer.endpoint == nil {
|
if peer.endpoint.val == nil {
|
||||||
peer.RUnlock()
|
peer.endpoint.Unlock()
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
nativeEP, _ := peer.endpoint.(*conn.StdNetEndpoint)
|
nativeEP, _ := peer.endpoint.val.(*conn.StdNetEndpoint)
|
||||||
if nativeEP == nil {
|
if nativeEP == nil {
|
||||||
peer.RUnlock()
|
peer.endpoint.Unlock()
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if nativeEP.DstIP().Is6() || nativeEP.SrcIfidx() == 0 {
|
if nativeEP.DstIP().Is6() || nativeEP.SrcIfidx() == 0 {
|
||||||
peer.RUnlock()
|
peer.endpoint.Unlock()
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
nlmsg := struct {
|
nlmsg := struct {
|
||||||
|
@ -188,10 +188,10 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl
|
||||||
reqPeerLock.Lock()
|
reqPeerLock.Lock()
|
||||||
reqPeer[i] = peerEndpointPtr{
|
reqPeer[i] = peerEndpointPtr{
|
||||||
peer: peer,
|
peer: peer,
|
||||||
endpoint: &peer.endpoint,
|
endpoint: &peer.endpoint.val,
|
||||||
}
|
}
|
||||||
reqPeerLock.Unlock()
|
reqPeerLock.Unlock()
|
||||||
peer.RUnlock()
|
peer.endpoint.Unlock()
|
||||||
i++
|
i++
|
||||||
_, err := netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:])
|
_, err := netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -100,11 +100,7 @@ func expiredRetransmitHandshake(peer *Peer) {
|
||||||
peer.device.log.Verbosef("%s - Handshake did not complete after %d seconds, retrying (try %d)", peer, int(RekeyTimeout.Seconds()), peer.timers.handshakeAttempts.Load()+1)
|
peer.device.log.Verbosef("%s - Handshake did not complete after %d seconds, retrying (try %d)", peer, int(RekeyTimeout.Seconds()), peer.timers.handshakeAttempts.Load()+1)
|
||||||
|
|
||||||
/* We clear the endpoint address src address, in case this is the cause of trouble. */
|
/* We clear the endpoint address src address, in case this is the cause of trouble. */
|
||||||
peer.Lock()
|
peer.markEndpointSrcForClearing()
|
||||||
if peer.endpoint != nil {
|
|
||||||
peer.endpoint.ClearSrc()
|
|
||||||
}
|
|
||||||
peer.Unlock()
|
|
||||||
|
|
||||||
peer.SendHandshakeInitiation(true)
|
peer.SendHandshakeInitiation(true)
|
||||||
}
|
}
|
||||||
|
@ -123,11 +119,7 @@ func expiredSendKeepalive(peer *Peer) {
|
||||||
func expiredNewHandshake(peer *Peer) {
|
func expiredNewHandshake(peer *Peer) {
|
||||||
peer.device.log.Verbosef("%s - Retrying handshake because we stopped hearing back after %d seconds", peer, int((KeepaliveTimeout + RekeyTimeout).Seconds()))
|
peer.device.log.Verbosef("%s - Retrying handshake because we stopped hearing back after %d seconds", peer, int((KeepaliveTimeout + RekeyTimeout).Seconds()))
|
||||||
/* We clear the endpoint address src address, in case this is the cause of trouble. */
|
/* We clear the endpoint address src address, in case this is the cause of trouble. */
|
||||||
peer.Lock()
|
peer.markEndpointSrcForClearing()
|
||||||
if peer.endpoint != nil {
|
|
||||||
peer.endpoint.ClearSrc()
|
|
||||||
}
|
|
||||||
peer.Unlock()
|
|
||||||
peer.SendHandshakeInitiation(false)
|
peer.SendHandshakeInitiation(false)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -8,7 +8,7 @@ package device
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/amnezia-vpn/amnezia-wg/tun"
|
"github.com/amnezia-vpn/amneziawg-go/tun"
|
||||||
)
|
)
|
||||||
|
|
||||||
const DefaultMTU = 1420
|
const DefaultMTU = 1420
|
||||||
|
|
115
device/uapi.go
115
device/uapi.go
|
@ -18,7 +18,7 @@ import (
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/amnezia-vpn/amnezia-wg/ipc"
|
"github.com/amnezia-vpn/amneziawg-go/ipc"
|
||||||
)
|
)
|
||||||
|
|
||||||
type IPCError struct {
|
type IPCError struct {
|
||||||
|
@ -129,36 +129,31 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
|
||||||
|
|
||||||
for _, peer := range device.peers.keyMap {
|
for _, peer := range device.peers.keyMap {
|
||||||
// Serialize peer state.
|
// Serialize peer state.
|
||||||
// Do the work in an anonymous function so that we can use defer.
|
peer.handshake.mutex.RLock()
|
||||||
func() {
|
keyf("public_key", (*[32]byte)(&peer.handshake.remoteStatic))
|
||||||
peer.RLock()
|
keyf("preshared_key", (*[32]byte)(&peer.handshake.presharedKey))
|
||||||
defer peer.RUnlock()
|
peer.handshake.mutex.RUnlock()
|
||||||
|
sendf("protocol_version=1")
|
||||||
|
peer.endpoint.Lock()
|
||||||
|
if peer.endpoint.val != nil {
|
||||||
|
sendf("endpoint=%s", peer.endpoint.val.DstToString())
|
||||||
|
}
|
||||||
|
peer.endpoint.Unlock()
|
||||||
|
|
||||||
keyf("public_key", (*[32]byte)(&peer.handshake.remoteStatic))
|
nano := peer.lastHandshakeNano.Load()
|
||||||
keyf("preshared_key", (*[32]byte)(&peer.handshake.presharedKey))
|
secs := nano / time.Second.Nanoseconds()
|
||||||
sendf("protocol_version=1")
|
nano %= time.Second.Nanoseconds()
|
||||||
if peer.endpoint != nil {
|
|
||||||
sendf("endpoint=%s", peer.endpoint.DstToString())
|
|
||||||
}
|
|
||||||
|
|
||||||
nano := peer.lastHandshakeNano.Load()
|
sendf("last_handshake_time_sec=%d", secs)
|
||||||
secs := nano / time.Second.Nanoseconds()
|
sendf("last_handshake_time_nsec=%d", nano)
|
||||||
nano %= time.Second.Nanoseconds()
|
sendf("tx_bytes=%d", peer.txBytes.Load())
|
||||||
|
sendf("rx_bytes=%d", peer.rxBytes.Load())
|
||||||
|
sendf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval.Load())
|
||||||
|
|
||||||
sendf("last_handshake_time_sec=%d", secs)
|
device.allowedips.EntriesForPeer(peer, func(prefix netip.Prefix) bool {
|
||||||
sendf("last_handshake_time_nsec=%d", nano)
|
sendf("allowed_ip=%s", prefix.String())
|
||||||
sendf("tx_bytes=%d", peer.txBytes.Load())
|
return true
|
||||||
sendf("rx_bytes=%d", peer.rxBytes.Load())
|
})
|
||||||
sendf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval.Load())
|
|
||||||
|
|
||||||
device.allowedips.EntriesForPeer(
|
|
||||||
peer,
|
|
||||||
func(prefix netip.Prefix) bool {
|
|
||||||
sendf("allowed_ip=%s", prefix.String())
|
|
||||||
return true
|
|
||||||
},
|
|
||||||
)
|
|
||||||
}()
|
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
@ -191,7 +186,10 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
|
||||||
line := scanner.Text()
|
line := scanner.Text()
|
||||||
if line == "" {
|
if line == "" {
|
||||||
// Blank line means terminate operation.
|
// Blank line means terminate operation.
|
||||||
device.handlePostConfig(&tempASecCfg)
|
err := device.handlePostConfig(&tempASecCfg)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
peer.handlePostConfig()
|
peer.handlePostConfig()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -227,7 +225,10 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
device.handlePostConfig(&tempASecCfg)
|
err = device.handlePostConfig(&tempASecCfg)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
peer.handlePostConfig()
|
peer.handlePostConfig()
|
||||||
|
|
||||||
if err := scanner.Err(); err != nil {
|
if err := scanner.Err(); err != nil {
|
||||||
|
@ -242,7 +243,7 @@ func (device *Device) handleDeviceLine(key, value string, tempASecCfg *aSecCfgTy
|
||||||
var sk NoisePrivateKey
|
var sk NoisePrivateKey
|
||||||
err := sk.FromMaybeZeroHex(value)
|
err := sk.FromMaybeZeroHex(value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ipcErrorf(ipc.IpcErrorInvalid,"failed to set private_key: %w",err)
|
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set private_key: %w", err)
|
||||||
}
|
}
|
||||||
device.log.Verbosef("UAPI: Updating private key")
|
device.log.Verbosef("UAPI: Updating private key")
|
||||||
device.SetPrivateKey(sk)
|
device.SetPrivateKey(sk)
|
||||||
|
@ -250,7 +251,7 @@ func (device *Device) handleDeviceLine(key, value string, tempASecCfg *aSecCfgTy
|
||||||
case "listen_port":
|
case "listen_port":
|
||||||
port, err := strconv.ParseUint(value, 10, 16)
|
port, err := strconv.ParseUint(value, 10, 16)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ipcErrorf(ipc.IpcErrorInvalid,"failed to parse listen_port: %w",err)
|
return ipcErrorf(ipc.IpcErrorInvalid, "failed to parse listen_port: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// update port and rebind
|
// update port and rebind
|
||||||
|
@ -261,7 +262,7 @@ func (device *Device) handleDeviceLine(key, value string, tempASecCfg *aSecCfgTy
|
||||||
device.net.Unlock()
|
device.net.Unlock()
|
||||||
|
|
||||||
if err := device.BindUpdate(); err != nil {
|
if err := device.BindUpdate(); err != nil {
|
||||||
return ipcErrorf(ipc.IpcErrorPortInUse,"failed to set listen_port: %w",err)
|
return ipcErrorf(ipc.IpcErrorPortInUse, "failed to set listen_port: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
case "fwmark":
|
case "fwmark":
|
||||||
|
@ -272,12 +273,12 @@ func (device *Device) handleDeviceLine(key, value string, tempASecCfg *aSecCfgTy
|
||||||
|
|
||||||
device.log.Verbosef("UAPI: Updating fwmark")
|
device.log.Verbosef("UAPI: Updating fwmark")
|
||||||
if err := device.BindSetMark(uint32(mark)); err != nil {
|
if err := device.BindSetMark(uint32(mark)); err != nil {
|
||||||
return ipcErrorf(ipc.IpcErrorPortInUse,"failed to update fwmark: %w", err)
|
return ipcErrorf(ipc.IpcErrorPortInUse, "failed to update fwmark: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
case "replace_peers":
|
case "replace_peers":
|
||||||
if value != "true" {
|
if value != "true" {
|
||||||
return ipcErrorf(ipc.IpcErrorInvalid,"failed to set replace_peers, invalid value: %v", value)
|
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set replace_peers, invalid value: %v", value)
|
||||||
}
|
}
|
||||||
device.log.Verbosef("UAPI: Removing all peers")
|
device.log.Verbosef("UAPI: Removing all peers")
|
||||||
device.RemoveAllPeers()
|
device.RemoveAllPeers()
|
||||||
|
@ -287,65 +288,80 @@ func (device *Device) handleDeviceLine(key, value string, tempASecCfg *aSecCfgTy
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse junk_packet_count %w", err)
|
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse junk_packet_count %w", err)
|
||||||
}
|
}
|
||||||
|
device.log.Verbosef("UAPI: Updating junk_packet_count")
|
||||||
tempASecCfg.junkPacketCount = junkPacketCount
|
tempASecCfg.junkPacketCount = junkPacketCount
|
||||||
|
tempASecCfg.isSet = true
|
||||||
|
|
||||||
case "jmin":
|
case "jmin":
|
||||||
junkPacketMinSize, err := strconv.Atoi(value)
|
junkPacketMinSize, err := strconv.Atoi(value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ipcErrorf(ipc.IpcErrorInvalid,"faield to parse junk_packet_min_size %w", err)
|
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse junk_packet_min_size %w", err)
|
||||||
}
|
}
|
||||||
|
device.log.Verbosef("UAPI: Updating junk_packet_min_size")
|
||||||
tempASecCfg.junkPacketMinSize = junkPacketMinSize
|
tempASecCfg.junkPacketMinSize = junkPacketMinSize
|
||||||
|
tempASecCfg.isSet = true
|
||||||
|
|
||||||
case "jmax":
|
case "jmax":
|
||||||
junkPacketMaxSize, err := strconv.Atoi(value)
|
junkPacketMaxSize, err := strconv.Atoi(value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ipcErrorf(ipc.IpcErrorInvalid,"faield to parse junk_packet_max_size %w", err)
|
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse junk_packet_max_size %w", err)
|
||||||
}
|
}
|
||||||
|
device.log.Verbosef("UAPI: Updating junk_packet_max_size")
|
||||||
tempASecCfg.junkPacketMaxSize = junkPacketMaxSize
|
tempASecCfg.junkPacketMaxSize = junkPacketMaxSize
|
||||||
|
tempASecCfg.isSet = true
|
||||||
|
|
||||||
case "s1":
|
case "s1":
|
||||||
initPacketJunkSize, err := strconv.Atoi(value)
|
initPacketJunkSize, err := strconv.Atoi(value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ipcErrorf(ipc.IpcErrorInvalid,"faield to parse init_packet_junk_size %w", err)
|
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse init_packet_junk_size %w", err)
|
||||||
}
|
}
|
||||||
|
device.log.Verbosef("UAPI: Updating init_packet_junk_size")
|
||||||
tempASecCfg.initPacketJunkSize = initPacketJunkSize
|
tempASecCfg.initPacketJunkSize = initPacketJunkSize
|
||||||
|
tempASecCfg.isSet = true
|
||||||
|
|
||||||
case "s2":
|
case "s2":
|
||||||
responsePacketJunkSize, err := strconv.Atoi(value)
|
responsePacketJunkSize, err := strconv.Atoi(value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ipcErrorf(ipc.IpcErrorInvalid,"faield to parse response_packet_junk_size %w", err)
|
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse response_packet_junk_size %w", err)
|
||||||
}
|
}
|
||||||
|
device.log.Verbosef("UAPI: Updating response_packet_junk_size")
|
||||||
tempASecCfg.responsePacketJunkSize = responsePacketJunkSize
|
tempASecCfg.responsePacketJunkSize = responsePacketJunkSize
|
||||||
|
tempASecCfg.isSet = true
|
||||||
|
|
||||||
case "h1":
|
case "h1":
|
||||||
initPacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
|
initPacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ipcErrorf(ipc.IpcErrorInvalid,"faield to parse init_packet_magic_header %w", err)
|
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse init_packet_magic_header %w", err)
|
||||||
}
|
}
|
||||||
tempASecCfg.initPacketMagicHeader = uint32(initPacketMagicHeader)
|
tempASecCfg.initPacketMagicHeader = uint32(initPacketMagicHeader)
|
||||||
|
tempASecCfg.isSet = true
|
||||||
|
|
||||||
case "h2":
|
case "h2":
|
||||||
responsePacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
|
responsePacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ipcErrorf(ipc.IpcErrorInvalid,"faield to parse response_packet_magic_header %w", err)
|
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse response_packet_magic_header %w", err)
|
||||||
}
|
}
|
||||||
tempASecCfg.responsePacketMagicHeader = uint32(responsePacketMagicHeader)
|
tempASecCfg.responsePacketMagicHeader = uint32(responsePacketMagicHeader)
|
||||||
|
tempASecCfg.isSet = true
|
||||||
|
|
||||||
case "h3":
|
case "h3":
|
||||||
underloadPacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
|
underloadPacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ipcErrorf(ipc.IpcErrorInvalid,"faield to parse underload_packet_magic_header %w", err)
|
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse underload_packet_magic_header %w", err)
|
||||||
}
|
}
|
||||||
tempASecCfg.underloadPacketMagicHeader = uint32(underloadPacketMagicHeader)
|
tempASecCfg.underloadPacketMagicHeader = uint32(underloadPacketMagicHeader)
|
||||||
|
tempASecCfg.isSet = true
|
||||||
|
|
||||||
case "h4":
|
case "h4":
|
||||||
transportPacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
|
transportPacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ipcErrorf(ipc.IpcErrorInvalid,"faield to parse transport_packet_magic_header %w", err)
|
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse transport_packet_magic_header %w", err)
|
||||||
}
|
}
|
||||||
tempASecCfg.transportPacketMagicHeader = uint32(transportPacketMagicHeader)
|
tempASecCfg.transportPacketMagicHeader = uint32(transportPacketMagicHeader)
|
||||||
|
tempASecCfg.isSet = true
|
||||||
|
|
||||||
default:
|
default:
|
||||||
return ipcErrorf(ipc.IpcErrorInvalid,"invalid UAPI device key: %v",key)
|
return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI device key: %v", key)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
@ -364,8 +380,7 @@ func (peer *ipcSetPeer) handlePostConfig() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if peer.created {
|
if peer.created {
|
||||||
peer.disableRoaming = peer.device.net.brokenRoaming &&
|
peer.endpoint.disableRoaming = peer.device.net.brokenRoaming && peer.endpoint.val != nil
|
||||||
peer.endpoint != nil
|
|
||||||
}
|
}
|
||||||
if peer.device.isUp() {
|
if peer.device.isUp() {
|
||||||
peer.Start()
|
peer.Start()
|
||||||
|
@ -452,11 +467,11 @@ func (device *Device) handlePeerLine(
|
||||||
device.log.Verbosef("%v - UAPI: Updating endpoint", peer.Peer)
|
device.log.Verbosef("%v - UAPI: Updating endpoint", peer.Peer)
|
||||||
endpoint, err := device.net.bind.ParseEndpoint(value)
|
endpoint, err := device.net.bind.ParseEndpoint(value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set endpoint %v: %w", value, err)
|
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set endpoint %v: %w", value, err)
|
||||||
}
|
}
|
||||||
peer.Lock()
|
peer.endpoint.Lock()
|
||||||
defer peer.Unlock()
|
defer peer.endpoint.Unlock()
|
||||||
peer.endpoint = endpoint
|
peer.endpoint.val = endpoint
|
||||||
|
|
||||||
case "persistent_keepalive_interval":
|
case "persistent_keepalive_interval":
|
||||||
device.log.Verbosef("%v - UAPI: Updating persistent keepalive interval", peer.Peer)
|
device.log.Verbosef("%v - UAPI: Updating persistent keepalive interval", peer.Peer)
|
||||||
|
|
|
@ -1,25 +0,0 @@
|
||||||
package device
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
crand "crypto/rand"
|
|
||||||
"fmt"
|
|
||||||
)
|
|
||||||
|
|
||||||
func appendJunk(writer *bytes.Buffer, size int) error {
|
|
||||||
headerJunk, err := randomJunkWithSize(size)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to create header junk: %v", err)
|
|
||||||
}
|
|
||||||
_, err = writer.Write(headerJunk)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to write header junk: %v", err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func randomJunkWithSize(size int) ([]byte, error) {
|
|
||||||
junk := make([]byte, size)
|
|
||||||
_, err := crand.Read(junk)
|
|
||||||
return junk, err
|
|
||||||
}
|
|
|
@ -1,27 +0,0 @@
|
||||||
package device
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"fmt"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func Test_randomJunktWithSize(t *testing.T) {
|
|
||||||
junk, err := randomJunkWithSize(30)
|
|
||||||
fmt.Println(string(junk), len(junk), err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_appendJunk(t *testing.T) {
|
|
||||||
t.Run("", func(t *testing.T) {
|
|
||||||
s := "apple"
|
|
||||||
buffer := bytes.NewBuffer([]byte(s))
|
|
||||||
err := appendJunk(buffer, 30)
|
|
||||||
if err != nil &&
|
|
||||||
buffer.Len() != len(s)+30 {
|
|
||||||
t.Errorf("appendWithJunk() size don't match")
|
|
||||||
}
|
|
||||||
read := make([]byte, 50)
|
|
||||||
buffer.Read(read)
|
|
||||||
fmt.Println(string(read))
|
|
||||||
})
|
|
||||||
}
|
|
16
go.mod
16
go.mod
|
@ -1,17 +1,17 @@
|
||||||
module github.com/amnezia-vpn/amnezia-wg
|
module github.com/amnezia-vpn/amneziawg-go
|
||||||
|
|
||||||
go 1.20
|
go 1.24
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/tevino/abool/v2 v2.1.0
|
github.com/tevino/abool/v2 v2.1.0
|
||||||
golang.org/x/crypto v0.6.0
|
golang.org/x/crypto v0.36.0
|
||||||
golang.org/x/net v0.7.0
|
golang.org/x/net v0.37.0
|
||||||
golang.org/x/sys v0.5.1-0.20230222185716-a3b23cc77e89
|
golang.org/x/sys v0.31.0
|
||||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
|
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
|
||||||
gvisor.dev/gvisor v0.0.0-20221203005347-703fd9b7fbc0
|
gvisor.dev/gvisor v0.0.0-20250130013005-04f9204697c6
|
||||||
)
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/google/btree v1.0.1 // indirect
|
github.com/google/btree v1.1.3 // indirect
|
||||||
golang.org/x/time v0.0.0-20191024005414-555d28b269f0 // indirect
|
golang.org/x/time v0.9.0 // indirect
|
||||||
)
|
)
|
||||||
|
|
28
go.sum
28
go.sum
|
@ -1,16 +1,20 @@
|
||||||
github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4=
|
github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg=
|
||||||
github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA=
|
github.com/google/btree v1.1.3/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4=
|
||||||
|
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
||||||
|
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||||
github.com/tevino/abool/v2 v2.1.0 h1:7w+Vf9f/5gmKT4m4qkayb33/92M+Um45F2BkHOR+L/c=
|
github.com/tevino/abool/v2 v2.1.0 h1:7w+Vf9f/5gmKT4m4qkayb33/92M+Um45F2BkHOR+L/c=
|
||||||
github.com/tevino/abool/v2 v2.1.0/go.mod h1:+Lmlqk6bHDWHqN1cbxqhwEAwMPXgc8I1SDEamtseuXY=
|
github.com/tevino/abool/v2 v2.1.0/go.mod h1:+Lmlqk6bHDWHqN1cbxqhwEAwMPXgc8I1SDEamtseuXY=
|
||||||
golang.org/x/crypto v0.6.0 h1:qfktjS5LUO+fFKeJXZ+ikTRijMmljikvG68fpMMruSc=
|
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
|
||||||
golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58=
|
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
|
||||||
golang.org/x/net v0.7.0 h1:rJrUqqhjsgNp7KqAIc25s9pZnjU7TUcSY7HcVZjdn1g=
|
golang.org/x/mod v0.21.0 h1:vvrHzRwRfVKSiLrG+d4FMl/Qi4ukBCE6kZlTUkDYRT0=
|
||||||
golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
|
golang.org/x/mod v0.21.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY=
|
||||||
golang.org/x/sys v0.5.1-0.20230222185716-a3b23cc77e89 h1:260HNjMTPDya+jq5AM1zZLgG9pv9GASPAGiEEJUbRg4=
|
golang.org/x/net v0.37.0 h1:1zLorHbz+LYj7MQlSf1+2tPIIgibq2eL5xkrGk6f+2c=
|
||||||
golang.org/x/sys v0.5.1-0.20230222185716-a3b23cc77e89/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/net v0.37.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
|
||||||
golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs=
|
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
|
||||||
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||||
|
golang.org/x/time v0.9.0 h1:EsRrnYcQiGH+5FfbgvV4AP7qEZstoyrHB0DzarOQ4ZY=
|
||||||
|
golang.org/x/time v0.9.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
|
||||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
|
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
|
||||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
|
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
|
||||||
gvisor.dev/gvisor v0.0.0-20221203005347-703fd9b7fbc0 h1:Wobr37noukisGxpKo5jAsLREcpj61RxrWYzD8uwveOY=
|
gvisor.dev/gvisor v0.0.0-20250130013005-04f9204697c6 h1:6B7MdW3OEbJqOMr7cEYU9bkzvCjUBX/JlXk12xcANuQ=
|
||||||
gvisor.dev/gvisor v0.0.0-20221203005347-703fd9b7fbc0/go.mod h1:Dn5idtptoW1dIos9U6A2rpebLs/MtTwFacjKb8jLdQA=
|
gvisor.dev/gvisor v0.0.0-20250130013005-04f9204697c6/go.mod h1:5DMfjtclAbTIjbXqO1qCe2K5GKKxWz2JHvCChuTcJEM=
|
||||||
|
|
|
@ -20,7 +20,7 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/amnezia-vpn/amnezia-wg/ipc/namedpipe"
|
"github.com/amnezia-vpn/amneziawg-go/ipc/namedpipe"
|
||||||
"golang.org/x/sys/windows"
|
"golang.org/x/sys/windows"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -9,7 +9,7 @@ import (
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
"github.com/amnezia-vpn/amnezia-wg/rwcancel"
|
"github.com/amnezia-vpn/amneziawg-go/rwcancel"
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -26,7 +26,7 @@ const (
|
||||||
|
|
||||||
// socketDirectory is variable because it is modified by a linker
|
// socketDirectory is variable because it is modified by a linker
|
||||||
// flag in wireguard-android.
|
// flag in wireguard-android.
|
||||||
var socketDirectory = "/var/run/wireguard"
|
var socketDirectory = "/var/run/amneziawg"
|
||||||
|
|
||||||
func sockPath(iface string) string {
|
func sockPath(iface string) string {
|
||||||
return fmt.Sprintf("%s/%s.sock", socketDirectory, iface)
|
return fmt.Sprintf("%s/%s.sock", socketDirectory, iface)
|
||||||
|
|
|
@ -8,7 +8,7 @@ package ipc
|
||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
|
|
||||||
"github.com/amnezia-vpn/amnezia-wg/ipc/namedpipe"
|
"github.com/amnezia-vpn/amneziawg-go/ipc/namedpipe"
|
||||||
"golang.org/x/sys/windows"
|
"golang.org/x/sys/windows"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -62,7 +62,7 @@ func init() {
|
||||||
func UAPIListen(name string) (net.Listener, error) {
|
func UAPIListen(name string) (net.Listener, error) {
|
||||||
listener, err := (&namedpipe.ListenConfig{
|
listener, err := (&namedpipe.ListenConfig{
|
||||||
SecurityDescriptor: UAPISecurityDescriptor,
|
SecurityDescriptor: UAPISecurityDescriptor,
|
||||||
}).Listen(`\\.\pipe\ProtectedPrefix\Administrators\WireGuard\` + name)
|
}).Listen(`\\.\pipe\ProtectedPrefix\Administrators\AmneziaWG\` + name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
30
main.go
30
main.go
|
@ -14,10 +14,10 @@ import (
|
||||||
"runtime"
|
"runtime"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
"github.com/amnezia-vpn/amnezia-wg/conn"
|
"github.com/amnezia-vpn/amneziawg-go/conn"
|
||||||
"github.com/amnezia-vpn/amnezia-wg/device"
|
"github.com/amnezia-vpn/amneziawg-go/device"
|
||||||
"github.com/amnezia-vpn/amnezia-wg/ipc"
|
"github.com/amnezia-vpn/amneziawg-go/ipc"
|
||||||
"github.com/amnezia-vpn/amnezia-wg/tun"
|
"github.com/amnezia-vpn/amneziawg-go/tun"
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -46,20 +46,20 @@ func warning() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Fprintln(os.Stderr, "┌──────────────────────────────────────────────────────┐")
|
fmt.Fprintln(os.Stderr, "┌──────────────────────────────────────────────────────────────┐")
|
||||||
fmt.Fprintln(os.Stderr, "│ │")
|
fmt.Fprintln(os.Stderr, "│ │")
|
||||||
fmt.Fprintln(os.Stderr, "│ Running wireguard-go is not required because this │")
|
fmt.Fprintln(os.Stderr, "│ Running amneziawg-go is not required because this │")
|
||||||
fmt.Fprintln(os.Stderr, "│ kernel has first class support for WireGuard. For │")
|
fmt.Fprintln(os.Stderr, "│ kernel has first class support for AmneziaWG. For │")
|
||||||
fmt.Fprintln(os.Stderr, "│ information on installing the kernel module, │")
|
fmt.Fprintln(os.Stderr, "│ information on installing the kernel module, │")
|
||||||
fmt.Fprintln(os.Stderr, "│ please visit: │")
|
fmt.Fprintln(os.Stderr, "│ please visit: │")
|
||||||
fmt.Fprintln(os.Stderr, "│ https://www.wireguard.com/install/ │")
|
fmt.Fprintln(os.Stderr, "| https://github.com/amnezia-vpn/amneziawg-linux-kernel-module │")
|
||||||
fmt.Fprintln(os.Stderr, "│ │")
|
fmt.Fprintln(os.Stderr, "│ │")
|
||||||
fmt.Fprintln(os.Stderr, "└──────────────────────────────────────────────────────┘")
|
fmt.Fprintln(os.Stderr, "└──────────────────────────────────────────────────────────────┘")
|
||||||
}
|
}
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
if len(os.Args) == 2 && os.Args[1] == "--version" {
|
if len(os.Args) == 2 && os.Args[1] == "--version" {
|
||||||
fmt.Printf("wireguard-go v%s\n\nUserspace WireGuard daemon for %s-%s.\nInformation available at https://www.wireguard.com.\nCopyright (C) Jason A. Donenfeld <Jason@zx2c4.com>.\n", Version, runtime.GOOS, runtime.GOARCH)
|
fmt.Printf("amneziawg-go %s\n\nUserspace AmneziaWG daemon for %s-%s.\nInformation available at https://amnezia.org\n", Version, runtime.GOOS, runtime.GOARCH)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -145,7 +145,7 @@ func main() {
|
||||||
fmt.Sprintf("(%s) ", interfaceName),
|
fmt.Sprintf("(%s) ", interfaceName),
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.Verbosef("Starting wireguard-go version %s", Version)
|
logger.Verbosef("Starting amneziawg-go version %s", Version)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Errorf("Failed to create TUN device: %v", err)
|
logger.Errorf("Failed to create TUN device: %v", err)
|
||||||
|
|
|
@ -12,11 +12,11 @@ import (
|
||||||
|
|
||||||
"golang.org/x/sys/windows"
|
"golang.org/x/sys/windows"
|
||||||
|
|
||||||
"github.com/amnezia-vpn/amnezia-wg/conn"
|
"github.com/amnezia-vpn/amneziawg-go/conn"
|
||||||
"github.com/amnezia-vpn/amnezia-wg/device"
|
"github.com/amnezia-vpn/amneziawg-go/device"
|
||||||
"github.com/amnezia-vpn/amnezia-wg/ipc"
|
"github.com/amnezia-vpn/amneziawg-go/ipc"
|
||||||
|
|
||||||
"github.com/amnezia-vpn/amnezia-wg/tun"
|
"github.com/amnezia-vpn/amneziawg-go/tun"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -30,13 +30,13 @@ func main() {
|
||||||
}
|
}
|
||||||
interfaceName := os.Args[1]
|
interfaceName := os.Args[1]
|
||||||
|
|
||||||
fmt.Fprintln(os.Stderr, "Warning: this is a test program for Windows, mainly used for debugging this Go package. For a real WireGuard for Windows client, the repo you want is <https://git.zx2c4.com/wireguard-windows/>, which includes this code as a module.")
|
fmt.Fprintln(os.Stderr, "Warning: this is a test program for Windows, mainly used for debugging this Go package. For a real AmneziaWG for Windows client, please visit: https://amnezia.org")
|
||||||
|
|
||||||
logger := device.NewLogger(
|
logger := device.NewLogger(
|
||||||
device.LogLevelVerbose,
|
device.LogLevelVerbose,
|
||||||
fmt.Sprintf("(%s) ", interfaceName),
|
fmt.Sprintf("(%s) ", interfaceName),
|
||||||
)
|
)
|
||||||
logger.Verbosef("Starting wireguard-go version %s", Version)
|
logger.Verbosef("Starting amneziawg-go version %s", Version)
|
||||||
|
|
||||||
tun, err := tun.CreateTUN(interfaceName, 0)
|
tun, err := tun.CreateTUN(interfaceName, 0)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
|
|
100
tun/checksum.go
100
tun/checksum.go
|
@ -3,23 +3,99 @@ package tun
|
||||||
import "encoding/binary"
|
import "encoding/binary"
|
||||||
|
|
||||||
// TODO: Explore SIMD and/or other assembly optimizations.
|
// TODO: Explore SIMD and/or other assembly optimizations.
|
||||||
|
// TODO: Test native endian loads. See RFC 1071 section 2 part B.
|
||||||
func checksumNoFold(b []byte, initial uint64) uint64 {
|
func checksumNoFold(b []byte, initial uint64) uint64 {
|
||||||
ac := initial
|
ac := initial
|
||||||
i := 0
|
|
||||||
n := len(b)
|
for len(b) >= 128 {
|
||||||
for n >= 4 {
|
ac += uint64(binary.BigEndian.Uint32(b[:4]))
|
||||||
ac += uint64(binary.BigEndian.Uint32(b[i : i+4]))
|
ac += uint64(binary.BigEndian.Uint32(b[4:8]))
|
||||||
n -= 4
|
ac += uint64(binary.BigEndian.Uint32(b[8:12]))
|
||||||
i += 4
|
ac += uint64(binary.BigEndian.Uint32(b[12:16]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[16:20]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[20:24]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[24:28]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[28:32]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[32:36]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[36:40]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[40:44]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[44:48]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[48:52]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[52:56]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[56:60]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[60:64]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[64:68]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[68:72]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[72:76]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[76:80]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[80:84]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[84:88]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[88:92]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[92:96]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[96:100]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[100:104]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[104:108]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[108:112]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[112:116]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[116:120]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[120:124]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[124:128]))
|
||||||
|
b = b[128:]
|
||||||
}
|
}
|
||||||
for n >= 2 {
|
if len(b) >= 64 {
|
||||||
ac += uint64(binary.BigEndian.Uint16(b[i : i+2]))
|
ac += uint64(binary.BigEndian.Uint32(b[:4]))
|
||||||
n -= 2
|
ac += uint64(binary.BigEndian.Uint32(b[4:8]))
|
||||||
i += 2
|
ac += uint64(binary.BigEndian.Uint32(b[8:12]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[12:16]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[16:20]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[20:24]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[24:28]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[28:32]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[32:36]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[36:40]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[40:44]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[44:48]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[48:52]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[52:56]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[56:60]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[60:64]))
|
||||||
|
b = b[64:]
|
||||||
}
|
}
|
||||||
if n == 1 {
|
if len(b) >= 32 {
|
||||||
ac += uint64(b[i]) << 8
|
ac += uint64(binary.BigEndian.Uint32(b[:4]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[4:8]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[8:12]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[12:16]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[16:20]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[20:24]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[24:28]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[28:32]))
|
||||||
|
b = b[32:]
|
||||||
}
|
}
|
||||||
|
if len(b) >= 16 {
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[:4]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[4:8]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[8:12]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[12:16]))
|
||||||
|
b = b[16:]
|
||||||
|
}
|
||||||
|
if len(b) >= 8 {
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[:4]))
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b[4:8]))
|
||||||
|
b = b[8:]
|
||||||
|
}
|
||||||
|
if len(b) >= 4 {
|
||||||
|
ac += uint64(binary.BigEndian.Uint32(b))
|
||||||
|
b = b[4:]
|
||||||
|
}
|
||||||
|
if len(b) >= 2 {
|
||||||
|
ac += uint64(binary.BigEndian.Uint16(b))
|
||||||
|
b = b[2:]
|
||||||
|
}
|
||||||
|
if len(b) == 1 {
|
||||||
|
ac += uint64(b[0]) << 8
|
||||||
|
}
|
||||||
|
|
||||||
return ac
|
return ac
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
35
tun/checksum_test.go
Normal file
35
tun/checksum_test.go
Normal file
|
@ -0,0 +1,35 @@
|
||||||
|
package tun
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"math/rand"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func BenchmarkChecksum(b *testing.B) {
|
||||||
|
lengths := []int{
|
||||||
|
64,
|
||||||
|
128,
|
||||||
|
256,
|
||||||
|
512,
|
||||||
|
1024,
|
||||||
|
1500,
|
||||||
|
2048,
|
||||||
|
4096,
|
||||||
|
8192,
|
||||||
|
9000,
|
||||||
|
9001,
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, length := range lengths {
|
||||||
|
b.Run(fmt.Sprintf("%d", length), func(b *testing.B) {
|
||||||
|
buf := make([]byte, length)
|
||||||
|
rng := rand.New(rand.NewSource(1))
|
||||||
|
rng.Read(buf)
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
checksum(buf, 0)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -13,9 +13,9 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
|
||||||
"github.com/amnezia-vpn/amnezia-wg/conn"
|
"github.com/amnezia-vpn/amneziawg-go/conn"
|
||||||
"github.com/amnezia-vpn/amnezia-wg/device"
|
"github.com/amnezia-vpn/amneziawg-go/device"
|
||||||
"github.com/amnezia-vpn/amnezia-wg/tun/netstack"
|
"github.com/amnezia-vpn/amneziawg-go/tun/netstack"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
|
|
|
@ -14,9 +14,9 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
|
||||||
"github.com/amnezia-vpn/amnezia-wg/conn"
|
"github.com/amnezia-vpn/amneziawg-go/conn"
|
||||||
"github.com/amnezia-vpn/amnezia-wg/device"
|
"github.com/amnezia-vpn/amneziawg-go/device"
|
||||||
"github.com/amnezia-vpn/amnezia-wg/tun/netstack"
|
"github.com/amnezia-vpn/amneziawg-go/tun/netstack"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
|
|
|
@ -17,9 +17,9 @@ import (
|
||||||
"golang.org/x/net/icmp"
|
"golang.org/x/net/icmp"
|
||||||
"golang.org/x/net/ipv4"
|
"golang.org/x/net/ipv4"
|
||||||
|
|
||||||
"github.com/amnezia-vpn/amnezia-wg/conn"
|
"github.com/amnezia-vpn/amneziawg-go/conn"
|
||||||
"github.com/amnezia-vpn/amnezia-wg/device"
|
"github.com/amnezia-vpn/amneziawg-go/device"
|
||||||
"github.com/amnezia-vpn/amnezia-wg/tun/netstack"
|
"github.com/amnezia-vpn/amneziawg-go/tun/netstack"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
|
|
|
@ -22,10 +22,10 @@ import (
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/amnezia-vpn/amnezia-wg/tun"
|
"github.com/amnezia-vpn/amneziawg-go/tun"
|
||||||
|
|
||||||
"golang.org/x/net/dns/dnsmessage"
|
"golang.org/x/net/dns/dnsmessage"
|
||||||
"gvisor.dev/gvisor/pkg/bufferv2"
|
"gvisor.dev/gvisor/pkg/buffer"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip"
|
"gvisor.dev/gvisor/pkg/tcpip"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/header"
|
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||||
|
@ -43,7 +43,7 @@ type netTun struct {
|
||||||
ep *channel.Endpoint
|
ep *channel.Endpoint
|
||||||
stack *stack.Stack
|
stack *stack.Stack
|
||||||
events chan tun.Event
|
events chan tun.Event
|
||||||
incomingPacket chan *bufferv2.View
|
incomingPacket chan *buffer.View
|
||||||
mtu int
|
mtu int
|
||||||
dnsServers []netip.Addr
|
dnsServers []netip.Addr
|
||||||
hasV4, hasV6 bool
|
hasV4, hasV6 bool
|
||||||
|
@ -61,7 +61,7 @@ func CreateNetTUN(localAddresses, dnsServers []netip.Addr, mtu int) (tun.Device,
|
||||||
ep: channel.New(1024, uint32(mtu), ""),
|
ep: channel.New(1024, uint32(mtu), ""),
|
||||||
stack: stack.New(opts),
|
stack: stack.New(opts),
|
||||||
events: make(chan tun.Event, 10),
|
events: make(chan tun.Event, 10),
|
||||||
incomingPacket: make(chan *bufferv2.View),
|
incomingPacket: make(chan *buffer.View),
|
||||||
dnsServers: dnsServers,
|
dnsServers: dnsServers,
|
||||||
mtu: mtu,
|
mtu: mtu,
|
||||||
}
|
}
|
||||||
|
@ -84,7 +84,7 @@ func CreateNetTUN(localAddresses, dnsServers []netip.Addr, mtu int) (tun.Device,
|
||||||
}
|
}
|
||||||
protoAddr := tcpip.ProtocolAddress{
|
protoAddr := tcpip.ProtocolAddress{
|
||||||
Protocol: protoNumber,
|
Protocol: protoNumber,
|
||||||
AddressWithPrefix: tcpip.Address(ip.AsSlice()).WithPrefix(),
|
AddressWithPrefix: tcpip.AddrFromSlice(ip.AsSlice()).WithPrefix(),
|
||||||
}
|
}
|
||||||
tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{})
|
tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{})
|
||||||
if tcpipErr != nil {
|
if tcpipErr != nil {
|
||||||
|
@ -140,7 +140,7 @@ func (tun *netTun) Write(buf [][]byte, offset int) (int, error) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: bufferv2.MakeWithData(packet)})
|
pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: buffer.MakeWithData(packet)})
|
||||||
switch packet[0] >> 4 {
|
switch packet[0] >> 4 {
|
||||||
case 4:
|
case 4:
|
||||||
tun.ep.InjectInbound(header.IPv4ProtocolNumber, pkb)
|
tun.ep.InjectInbound(header.IPv4ProtocolNumber, pkb)
|
||||||
|
@ -198,7 +198,7 @@ func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.Networ
|
||||||
}
|
}
|
||||||
return tcpip.FullAddress{
|
return tcpip.FullAddress{
|
||||||
NIC: 1,
|
NIC: 1,
|
||||||
Addr: tcpip.Address(endpoint.Addr().AsSlice()),
|
Addr: tcpip.AddrFromSlice(endpoint.Addr().AsSlice()),
|
||||||
Port: endpoint.Port(),
|
Port: endpoint.Port(),
|
||||||
}, protoNumber
|
}, protoNumber
|
||||||
}
|
}
|
||||||
|
@ -453,7 +453,7 @@ func (pc *PingConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
||||||
return 0, nil, fmt.Errorf("ping read: %s", tcpipErr)
|
return 0, nil, fmt.Errorf("ping read: %s", tcpipErr)
|
||||||
}
|
}
|
||||||
|
|
||||||
remoteAddr, _ := netip.AddrFromSlice([]byte(res.RemoteAddr.Addr))
|
remoteAddr, _ := netip.AddrFromSlice(res.RemoteAddr.Addr.AsSlice())
|
||||||
return res.Count, &PingAddr{remoteAddr}, nil
|
return res.Count, &PingAddr{remoteAddr}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
993
tun/offload_linux.go
Normal file
993
tun/offload_linux.go
Normal file
|
@ -0,0 +1,993 @@
|
||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package tun
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"github.com/amnezia-vpn/amneziawg-go/conn"
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
)
|
||||||
|
|
||||||
|
const tcpFlagsOffset = 13
|
||||||
|
|
||||||
|
const (
|
||||||
|
tcpFlagFIN uint8 = 0x01
|
||||||
|
tcpFlagPSH uint8 = 0x08
|
||||||
|
tcpFlagACK uint8 = 0x10
|
||||||
|
)
|
||||||
|
|
||||||
|
// virtioNetHdr is defined in the kernel in include/uapi/linux/virtio_net.h. The
|
||||||
|
// kernel symbol is virtio_net_hdr.
|
||||||
|
type virtioNetHdr struct {
|
||||||
|
flags uint8
|
||||||
|
gsoType uint8
|
||||||
|
hdrLen uint16
|
||||||
|
gsoSize uint16
|
||||||
|
csumStart uint16
|
||||||
|
csumOffset uint16
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *virtioNetHdr) decode(b []byte) error {
|
||||||
|
if len(b) < virtioNetHdrLen {
|
||||||
|
return io.ErrShortBuffer
|
||||||
|
}
|
||||||
|
copy(unsafe.Slice((*byte)(unsafe.Pointer(v)), virtioNetHdrLen), b[:virtioNetHdrLen])
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *virtioNetHdr) encode(b []byte) error {
|
||||||
|
if len(b) < virtioNetHdrLen {
|
||||||
|
return io.ErrShortBuffer
|
||||||
|
}
|
||||||
|
copy(b[:virtioNetHdrLen], unsafe.Slice((*byte)(unsafe.Pointer(v)), virtioNetHdrLen))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
// virtioNetHdrLen is the length in bytes of virtioNetHdr. This matches the
|
||||||
|
// shape of the C ABI for its kernel counterpart -- sizeof(virtio_net_hdr).
|
||||||
|
virtioNetHdrLen = int(unsafe.Sizeof(virtioNetHdr{}))
|
||||||
|
)
|
||||||
|
|
||||||
|
// tcpFlowKey represents the key for a TCP flow.
|
||||||
|
type tcpFlowKey struct {
|
||||||
|
srcAddr, dstAddr [16]byte
|
||||||
|
srcPort, dstPort uint16
|
||||||
|
rxAck uint32 // varying ack values should not be coalesced. Treat them as separate flows.
|
||||||
|
isV6 bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// tcpGROTable holds flow and coalescing information for the purposes of TCP GRO.
|
||||||
|
type tcpGROTable struct {
|
||||||
|
itemsByFlow map[tcpFlowKey][]tcpGROItem
|
||||||
|
itemsPool [][]tcpGROItem
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTCPGROTable() *tcpGROTable {
|
||||||
|
t := &tcpGROTable{
|
||||||
|
itemsByFlow: make(map[tcpFlowKey][]tcpGROItem, conn.IdealBatchSize),
|
||||||
|
itemsPool: make([][]tcpGROItem, conn.IdealBatchSize),
|
||||||
|
}
|
||||||
|
for i := range t.itemsPool {
|
||||||
|
t.itemsPool[i] = make([]tcpGROItem, 0, conn.IdealBatchSize)
|
||||||
|
}
|
||||||
|
return t
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTCPFlowKey(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset int) tcpFlowKey {
|
||||||
|
key := tcpFlowKey{}
|
||||||
|
addrSize := dstAddrOffset - srcAddrOffset
|
||||||
|
copy(key.srcAddr[:], pkt[srcAddrOffset:dstAddrOffset])
|
||||||
|
copy(key.dstAddr[:], pkt[dstAddrOffset:dstAddrOffset+addrSize])
|
||||||
|
key.srcPort = binary.BigEndian.Uint16(pkt[tcphOffset:])
|
||||||
|
key.dstPort = binary.BigEndian.Uint16(pkt[tcphOffset+2:])
|
||||||
|
key.rxAck = binary.BigEndian.Uint32(pkt[tcphOffset+8:])
|
||||||
|
key.isV6 = addrSize == 16
|
||||||
|
return key
|
||||||
|
}
|
||||||
|
|
||||||
|
// lookupOrInsert looks up a flow for the provided packet and metadata,
|
||||||
|
// returning the packets found for the flow, or inserting a new one if none
|
||||||
|
// is found.
|
||||||
|
func (t *tcpGROTable) lookupOrInsert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) ([]tcpGROItem, bool) {
|
||||||
|
key := newTCPFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset)
|
||||||
|
items, ok := t.itemsByFlow[key]
|
||||||
|
if ok {
|
||||||
|
return items, ok
|
||||||
|
}
|
||||||
|
// TODO: insert() performs another map lookup. This could be rearranged to avoid.
|
||||||
|
t.insert(pkt, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex)
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// insert an item in the table for the provided packet and packet metadata.
|
||||||
|
func (t *tcpGROTable) insert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) {
|
||||||
|
key := newTCPFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset)
|
||||||
|
item := tcpGROItem{
|
||||||
|
key: key,
|
||||||
|
bufsIndex: uint16(bufsIndex),
|
||||||
|
gsoSize: uint16(len(pkt[tcphOffset+tcphLen:])),
|
||||||
|
iphLen: uint8(tcphOffset),
|
||||||
|
tcphLen: uint8(tcphLen),
|
||||||
|
sentSeq: binary.BigEndian.Uint32(pkt[tcphOffset+4:]),
|
||||||
|
pshSet: pkt[tcphOffset+tcpFlagsOffset]&tcpFlagPSH != 0,
|
||||||
|
}
|
||||||
|
items, ok := t.itemsByFlow[key]
|
||||||
|
if !ok {
|
||||||
|
items = t.newItems()
|
||||||
|
}
|
||||||
|
items = append(items, item)
|
||||||
|
t.itemsByFlow[key] = items
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tcpGROTable) updateAt(item tcpGROItem, i int) {
|
||||||
|
items, _ := t.itemsByFlow[item.key]
|
||||||
|
items[i] = item
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tcpGROTable) deleteAt(key tcpFlowKey, i int) {
|
||||||
|
items, _ := t.itemsByFlow[key]
|
||||||
|
items = append(items[:i], items[i+1:]...)
|
||||||
|
t.itemsByFlow[key] = items
|
||||||
|
}
|
||||||
|
|
||||||
|
// tcpGROItem represents bookkeeping data for a TCP packet during the lifetime
|
||||||
|
// of a GRO evaluation across a vector of packets.
|
||||||
|
type tcpGROItem struct {
|
||||||
|
key tcpFlowKey
|
||||||
|
sentSeq uint32 // the sequence number
|
||||||
|
bufsIndex uint16 // the index into the original bufs slice
|
||||||
|
numMerged uint16 // the number of packets merged into this item
|
||||||
|
gsoSize uint16 // payload size
|
||||||
|
iphLen uint8 // ip header len
|
||||||
|
tcphLen uint8 // tcp header len
|
||||||
|
pshSet bool // psh flag is set
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tcpGROTable) newItems() []tcpGROItem {
|
||||||
|
var items []tcpGROItem
|
||||||
|
items, t.itemsPool = t.itemsPool[len(t.itemsPool)-1], t.itemsPool[:len(t.itemsPool)-1]
|
||||||
|
return items
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tcpGROTable) reset() {
|
||||||
|
for k, items := range t.itemsByFlow {
|
||||||
|
items = items[:0]
|
||||||
|
t.itemsPool = append(t.itemsPool, items)
|
||||||
|
delete(t.itemsByFlow, k)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// udpFlowKey represents the key for a UDP flow.
|
||||||
|
type udpFlowKey struct {
|
||||||
|
srcAddr, dstAddr [16]byte
|
||||||
|
srcPort, dstPort uint16
|
||||||
|
isV6 bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// udpGROTable holds flow and coalescing information for the purposes of UDP GRO.
|
||||||
|
type udpGROTable struct {
|
||||||
|
itemsByFlow map[udpFlowKey][]udpGROItem
|
||||||
|
itemsPool [][]udpGROItem
|
||||||
|
}
|
||||||
|
|
||||||
|
func newUDPGROTable() *udpGROTable {
|
||||||
|
u := &udpGROTable{
|
||||||
|
itemsByFlow: make(map[udpFlowKey][]udpGROItem, conn.IdealBatchSize),
|
||||||
|
itemsPool: make([][]udpGROItem, conn.IdealBatchSize),
|
||||||
|
}
|
||||||
|
for i := range u.itemsPool {
|
||||||
|
u.itemsPool[i] = make([]udpGROItem, 0, conn.IdealBatchSize)
|
||||||
|
}
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
func newUDPFlowKey(pkt []byte, srcAddrOffset, dstAddrOffset, udphOffset int) udpFlowKey {
|
||||||
|
key := udpFlowKey{}
|
||||||
|
addrSize := dstAddrOffset - srcAddrOffset
|
||||||
|
copy(key.srcAddr[:], pkt[srcAddrOffset:dstAddrOffset])
|
||||||
|
copy(key.dstAddr[:], pkt[dstAddrOffset:dstAddrOffset+addrSize])
|
||||||
|
key.srcPort = binary.BigEndian.Uint16(pkt[udphOffset:])
|
||||||
|
key.dstPort = binary.BigEndian.Uint16(pkt[udphOffset+2:])
|
||||||
|
key.isV6 = addrSize == 16
|
||||||
|
return key
|
||||||
|
}
|
||||||
|
|
||||||
|
// lookupOrInsert looks up a flow for the provided packet and metadata,
|
||||||
|
// returning the packets found for the flow, or inserting a new one if none
|
||||||
|
// is found.
|
||||||
|
func (u *udpGROTable) lookupOrInsert(pkt []byte, srcAddrOffset, dstAddrOffset, udphOffset, bufsIndex int) ([]udpGROItem, bool) {
|
||||||
|
key := newUDPFlowKey(pkt, srcAddrOffset, dstAddrOffset, udphOffset)
|
||||||
|
items, ok := u.itemsByFlow[key]
|
||||||
|
if ok {
|
||||||
|
return items, ok
|
||||||
|
}
|
||||||
|
// TODO: insert() performs another map lookup. This could be rearranged to avoid.
|
||||||
|
u.insert(pkt, srcAddrOffset, dstAddrOffset, udphOffset, bufsIndex, false)
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// insert an item in the table for the provided packet and packet metadata.
|
||||||
|
func (u *udpGROTable) insert(pkt []byte, srcAddrOffset, dstAddrOffset, udphOffset, bufsIndex int, cSumKnownInvalid bool) {
|
||||||
|
key := newUDPFlowKey(pkt, srcAddrOffset, dstAddrOffset, udphOffset)
|
||||||
|
item := udpGROItem{
|
||||||
|
key: key,
|
||||||
|
bufsIndex: uint16(bufsIndex),
|
||||||
|
gsoSize: uint16(len(pkt[udphOffset+udphLen:])),
|
||||||
|
iphLen: uint8(udphOffset),
|
||||||
|
cSumKnownInvalid: cSumKnownInvalid,
|
||||||
|
}
|
||||||
|
items, ok := u.itemsByFlow[key]
|
||||||
|
if !ok {
|
||||||
|
items = u.newItems()
|
||||||
|
}
|
||||||
|
items = append(items, item)
|
||||||
|
u.itemsByFlow[key] = items
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *udpGROTable) updateAt(item udpGROItem, i int) {
|
||||||
|
items, _ := u.itemsByFlow[item.key]
|
||||||
|
items[i] = item
|
||||||
|
}
|
||||||
|
|
||||||
|
// udpGROItem represents bookkeeping data for a UDP packet during the lifetime
|
||||||
|
// of a GRO evaluation across a vector of packets.
|
||||||
|
type udpGROItem struct {
|
||||||
|
key udpFlowKey
|
||||||
|
bufsIndex uint16 // the index into the original bufs slice
|
||||||
|
numMerged uint16 // the number of packets merged into this item
|
||||||
|
gsoSize uint16 // payload size
|
||||||
|
iphLen uint8 // ip header len
|
||||||
|
cSumKnownInvalid bool // UDP header checksum validity; a false value DOES NOT imply valid, just unknown.
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *udpGROTable) newItems() []udpGROItem {
|
||||||
|
var items []udpGROItem
|
||||||
|
items, u.itemsPool = u.itemsPool[len(u.itemsPool)-1], u.itemsPool[:len(u.itemsPool)-1]
|
||||||
|
return items
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *udpGROTable) reset() {
|
||||||
|
for k, items := range u.itemsByFlow {
|
||||||
|
items = items[:0]
|
||||||
|
u.itemsPool = append(u.itemsPool, items)
|
||||||
|
delete(u.itemsByFlow, k)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// canCoalesce represents the outcome of checking if two TCP packets are
|
||||||
|
// candidates for coalescing.
|
||||||
|
type canCoalesce int
|
||||||
|
|
||||||
|
const (
|
||||||
|
coalescePrepend canCoalesce = -1
|
||||||
|
coalesceUnavailable canCoalesce = 0
|
||||||
|
coalesceAppend canCoalesce = 1
|
||||||
|
)
|
||||||
|
|
||||||
|
// ipHeadersCanCoalesce returns true if the IP headers found in pktA and pktB
|
||||||
|
// meet all requirements to be merged as part of a GRO operation, otherwise it
|
||||||
|
// returns false.
|
||||||
|
func ipHeadersCanCoalesce(pktA, pktB []byte) bool {
|
||||||
|
if len(pktA) < 9 || len(pktB) < 9 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if pktA[0]>>4 == 6 {
|
||||||
|
if pktA[0] != pktB[0] || pktA[1]>>4 != pktB[1]>>4 {
|
||||||
|
// cannot coalesce with unequal Traffic class values
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if pktA[7] != pktB[7] {
|
||||||
|
// cannot coalesce with unequal Hop limit values
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if pktA[1] != pktB[1] {
|
||||||
|
// cannot coalesce with unequal ToS values
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if pktA[6]>>5 != pktB[6]>>5 {
|
||||||
|
// cannot coalesce with unequal DF or reserved bits. MF is checked
|
||||||
|
// further up the stack.
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if pktA[8] != pktB[8] {
|
||||||
|
// cannot coalesce with unequal TTL values
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// udpPacketsCanCoalesce evaluates if pkt can be coalesced with the packet
|
||||||
|
// described by item. iphLen and gsoSize describe pkt. bufs is the vector of
|
||||||
|
// packets involved in the current GRO evaluation. bufsOffset is the offset at
|
||||||
|
// which packet data begins within bufs.
|
||||||
|
func udpPacketsCanCoalesce(pkt []byte, iphLen uint8, gsoSize uint16, item udpGROItem, bufs [][]byte, bufsOffset int) canCoalesce {
|
||||||
|
pktTarget := bufs[item.bufsIndex][bufsOffset:]
|
||||||
|
if !ipHeadersCanCoalesce(pkt, pktTarget) {
|
||||||
|
return coalesceUnavailable
|
||||||
|
}
|
||||||
|
if len(pktTarget[iphLen+udphLen:])%int(item.gsoSize) != 0 {
|
||||||
|
// A smaller than gsoSize packet has been appended previously.
|
||||||
|
// Nothing can come after a smaller packet on the end.
|
||||||
|
return coalesceUnavailable
|
||||||
|
}
|
||||||
|
if gsoSize > item.gsoSize {
|
||||||
|
// We cannot have a larger packet following a smaller one.
|
||||||
|
return coalesceUnavailable
|
||||||
|
}
|
||||||
|
return coalesceAppend
|
||||||
|
}
|
||||||
|
|
||||||
|
// tcpPacketsCanCoalesce evaluates if pkt can be coalesced with the packet
|
||||||
|
// described by item. This function makes considerations that match the kernel's
|
||||||
|
// GRO self tests, which can be found in tools/testing/selftests/net/gro.c.
|
||||||
|
func tcpPacketsCanCoalesce(pkt []byte, iphLen, tcphLen uint8, seq uint32, pshSet bool, gsoSize uint16, item tcpGROItem, bufs [][]byte, bufsOffset int) canCoalesce {
|
||||||
|
pktTarget := bufs[item.bufsIndex][bufsOffset:]
|
||||||
|
if tcphLen != item.tcphLen {
|
||||||
|
// cannot coalesce with unequal tcp options len
|
||||||
|
return coalesceUnavailable
|
||||||
|
}
|
||||||
|
if tcphLen > 20 {
|
||||||
|
if !bytes.Equal(pkt[iphLen+20:iphLen+tcphLen], pktTarget[item.iphLen+20:iphLen+tcphLen]) {
|
||||||
|
// cannot coalesce with unequal tcp options
|
||||||
|
return coalesceUnavailable
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !ipHeadersCanCoalesce(pkt, pktTarget) {
|
||||||
|
return coalesceUnavailable
|
||||||
|
}
|
||||||
|
// seq adjacency
|
||||||
|
lhsLen := item.gsoSize
|
||||||
|
lhsLen += item.numMerged * item.gsoSize
|
||||||
|
if seq == item.sentSeq+uint32(lhsLen) { // pkt aligns following item from a seq num perspective
|
||||||
|
if item.pshSet {
|
||||||
|
// We cannot append to a segment that has the PSH flag set, PSH
|
||||||
|
// can only be set on the final segment in a reassembled group.
|
||||||
|
return coalesceUnavailable
|
||||||
|
}
|
||||||
|
if len(pktTarget[iphLen+tcphLen:])%int(item.gsoSize) != 0 {
|
||||||
|
// A smaller than gsoSize packet has been appended previously.
|
||||||
|
// Nothing can come after a smaller packet on the end.
|
||||||
|
return coalesceUnavailable
|
||||||
|
}
|
||||||
|
if gsoSize > item.gsoSize {
|
||||||
|
// We cannot have a larger packet following a smaller one.
|
||||||
|
return coalesceUnavailable
|
||||||
|
}
|
||||||
|
return coalesceAppend
|
||||||
|
} else if seq+uint32(gsoSize) == item.sentSeq { // pkt aligns in front of item from a seq num perspective
|
||||||
|
if pshSet {
|
||||||
|
// We cannot prepend with a segment that has the PSH flag set, PSH
|
||||||
|
// can only be set on the final segment in a reassembled group.
|
||||||
|
return coalesceUnavailable
|
||||||
|
}
|
||||||
|
if gsoSize < item.gsoSize {
|
||||||
|
// We cannot have a larger packet following a smaller one.
|
||||||
|
return coalesceUnavailable
|
||||||
|
}
|
||||||
|
if gsoSize > item.gsoSize && item.numMerged > 0 {
|
||||||
|
// There's at least one previous merge, and we're larger than all
|
||||||
|
// previous. This would put multiple smaller packets on the end.
|
||||||
|
return coalesceUnavailable
|
||||||
|
}
|
||||||
|
return coalescePrepend
|
||||||
|
}
|
||||||
|
return coalesceUnavailable
|
||||||
|
}
|
||||||
|
|
||||||
|
func checksumValid(pkt []byte, iphLen, proto uint8, isV6 bool) bool {
|
||||||
|
srcAddrAt := ipv4SrcAddrOffset
|
||||||
|
addrSize := 4
|
||||||
|
if isV6 {
|
||||||
|
srcAddrAt = ipv6SrcAddrOffset
|
||||||
|
addrSize = 16
|
||||||
|
}
|
||||||
|
lenForPseudo := uint16(len(pkt) - int(iphLen))
|
||||||
|
cSum := pseudoHeaderChecksumNoFold(proto, pkt[srcAddrAt:srcAddrAt+addrSize], pkt[srcAddrAt+addrSize:srcAddrAt+addrSize*2], lenForPseudo)
|
||||||
|
return ^checksum(pkt[iphLen:], cSum) == 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// coalesceResult represents the result of attempting to coalesce two TCP
|
||||||
|
// packets.
|
||||||
|
type coalesceResult int
|
||||||
|
|
||||||
|
const (
|
||||||
|
coalesceInsufficientCap coalesceResult = iota
|
||||||
|
coalescePSHEnding
|
||||||
|
coalesceItemInvalidCSum
|
||||||
|
coalescePktInvalidCSum
|
||||||
|
coalesceSuccess
|
||||||
|
)
|
||||||
|
|
||||||
|
// coalesceUDPPackets attempts to coalesce pkt with the packet described by
|
||||||
|
// item, and returns the outcome.
|
||||||
|
func coalesceUDPPackets(pkt []byte, item *udpGROItem, bufs [][]byte, bufsOffset int, isV6 bool) coalesceResult {
|
||||||
|
pktHead := bufs[item.bufsIndex][bufsOffset:] // the packet that will end up at the front
|
||||||
|
headersLen := item.iphLen + udphLen
|
||||||
|
coalescedLen := len(bufs[item.bufsIndex][bufsOffset:]) + len(pkt) - int(headersLen)
|
||||||
|
|
||||||
|
if cap(pktHead)-bufsOffset < coalescedLen {
|
||||||
|
// We don't want to allocate a new underlying array if capacity is
|
||||||
|
// too small.
|
||||||
|
return coalesceInsufficientCap
|
||||||
|
}
|
||||||
|
if item.numMerged == 0 {
|
||||||
|
if item.cSumKnownInvalid || !checksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, unix.IPPROTO_UDP, isV6) {
|
||||||
|
return coalesceItemInvalidCSum
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !checksumValid(pkt, item.iphLen, unix.IPPROTO_UDP, isV6) {
|
||||||
|
return coalescePktInvalidCSum
|
||||||
|
}
|
||||||
|
extendBy := len(pkt) - int(headersLen)
|
||||||
|
bufs[item.bufsIndex] = append(bufs[item.bufsIndex], make([]byte, extendBy)...)
|
||||||
|
copy(bufs[item.bufsIndex][bufsOffset+len(pktHead):], pkt[headersLen:])
|
||||||
|
|
||||||
|
item.numMerged++
|
||||||
|
return coalesceSuccess
|
||||||
|
}
|
||||||
|
|
||||||
|
// coalesceTCPPackets attempts to coalesce pkt with the packet described by
// item, and returns the outcome. This function may swap bufs elements in the
// event of a prepend as item's bufs index is already being tracked for writing
// to a Device.
//
// mode selects prepend vs. append; pkt lives at bufs[pktBuffsIndex] starting
// at bufsOffset. On coalesceSuccess item's metadata (sentSeq/pshSet/gsoSize/
// numMerged) is updated in place.
func coalesceTCPPackets(mode canCoalesce, pkt []byte, pktBuffsIndex int, gsoSize uint16, seq uint32, pshSet bool, item *tcpGROItem, bufs [][]byte, bufsOffset int, isV6 bool) coalesceResult {
	var pktHead []byte // the packet that will end up at the front
	headersLen := item.iphLen + item.tcphLen
	// Resulting packet length: existing item bytes + pkt payload (headers of
	// one of the two are dropped).
	coalescedLen := len(bufs[item.bufsIndex][bufsOffset:]) + len(pkt) - int(headersLen)

	// Copy data
	if mode == coalescePrepend {
		pktHead = pkt
		if cap(pkt)-bufsOffset < coalescedLen {
			// We don't want to allocate a new underlying array if capacity is
			// too small.
			return coalesceInsufficientCap
		}
		if pshSet {
			return coalescePSHEnding
		}
		if item.numMerged == 0 {
			// Only the first merge needs to validate the existing item; later
			// merges have already proven it.
			if !checksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, unix.IPPROTO_TCP, isV6) {
				return coalesceItemInvalidCSum
			}
		}
		if !checksumValid(pkt, item.iphLen, unix.IPPROTO_TCP, isV6) {
			return coalescePktInvalidCSum
		}
		// pkt becomes the new head, so its sequence number is now the start
		// of the coalesced segment.
		item.sentSeq = seq
		extendBy := coalescedLen - len(pktHead)
		bufs[pktBuffsIndex] = append(bufs[pktBuffsIndex], make([]byte, extendBy)...)
		// Append the item's payload (sans headers) after pkt.
		copy(bufs[pktBuffsIndex][bufsOffset+len(pkt):], bufs[item.bufsIndex][bufsOffset+int(headersLen):])
		// Flip the slice headers in bufs as part of prepend. The index of item
		// is already being tracked for writing.
		bufs[item.bufsIndex], bufs[pktBuffsIndex] = bufs[pktBuffsIndex], bufs[item.bufsIndex]
	} else {
		pktHead = bufs[item.bufsIndex][bufsOffset:]
		if cap(pktHead)-bufsOffset < coalescedLen {
			// We don't want to allocate a new underlying array if capacity is
			// too small.
			return coalesceInsufficientCap
		}
		if item.numMerged == 0 {
			if !checksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, unix.IPPROTO_TCP, isV6) {
				return coalesceItemInvalidCSum
			}
		}
		if !checksumValid(pkt, item.iphLen, unix.IPPROTO_TCP, isV6) {
			return coalescePktInvalidCSum
		}
		if pshSet {
			// We are appending a segment with PSH set.
			item.pshSet = pshSet
			pktHead[item.iphLen+tcpFlagsOffset] |= tcpFlagPSH
		}
		extendBy := len(pkt) - int(headersLen)
		bufs[item.bufsIndex] = append(bufs[item.bufsIndex], make([]byte, extendBy)...)
		// Append pkt's payload (sans headers) after the existing item.
		copy(bufs[item.bufsIndex][bufsOffset+len(pktHead):], pkt[headersLen:])
	}

	if gsoSize > item.gsoSize {
		item.gsoSize = gsoSize
	}

	item.numMerged++
	return coalesceSuccess
}
|
||||||
|
|
||||||
|
const (
	// ipv4FlagMoreFragments is the MF (more fragments) bit within byte 6 of
	// an IPv4 header (the flags / fragment-offset-high byte).
	ipv4FlagMoreFragments uint8 = 0x20
)
|
||||||
|
|
||||||
|
const (
	ipv4SrcAddrOffset = 12 // byte offset of the source address in an IPv4 header
	ipv6SrcAddrOffset = 8  // byte offset of the source address in an IPv6 header
	maxUint16         = 1<<16 - 1
)
|
||||||
|
|
||||||
|
// groResult is the outcome of evaluating a single packet for GRO.
type groResult int

const (
	groResultNoop        groResult = iota // packet left untouched; caller must write it as-is
	groResultTableInsert                  // packet was inserted into a GRO table for later flushing
	groResultCoalesced                    // packet was merged into an existing table entry
)
|
||||||
|
|
||||||
|
// tcpGRO evaluates the TCP packet at pktI in bufs for coalescing with
// existing packets tracked in table. It returns a groResultNoop when no
// action was taken, groResultTableInsert when the evaluated packet was
// inserted into table, and groResultCoalesced when the evaluated packet was
// coalesced with another packet in table.
func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) groResult {
	pkt := bufs[pktI][offset:]
	if len(pkt) > maxUint16 {
		// A valid IPv4 or IPv6 packet will never exceed this.
		return groResultNoop
	}
	// IHL field counts 32-bit words (low nibble of the first byte).
	iphLen := int((pkt[0] & 0x0F) * 4)
	if isV6 {
		iphLen = 40
		// Verify the IPv6 payload-length field matches the bytes we hold.
		ipv6HPayloadLen := int(binary.BigEndian.Uint16(pkt[4:]))
		if ipv6HPayloadLen != len(pkt)-iphLen {
			return groResultNoop
		}
	} else {
		// Verify the IPv4 total-length field matches the bytes we hold.
		totalLen := int(binary.BigEndian.Uint16(pkt[2:]))
		if totalLen != len(pkt) {
			return groResultNoop
		}
	}
	if len(pkt) < iphLen {
		return groResultNoop
	}
	// TCP data offset counts 32-bit words (high nibble at byte 12 of TCP).
	tcphLen := int((pkt[iphLen+12] >> 4) * 4)
	if tcphLen < 20 || tcphLen > 60 {
		return groResultNoop
	}
	if len(pkt) < iphLen+tcphLen {
		return groResultNoop
	}
	if !isV6 {
		if pkt[6]&ipv4FlagMoreFragments != 0 || pkt[6]<<3 != 0 || pkt[7] != 0 {
			// no GRO support for fragmented segments for now
			return groResultNoop
		}
	}
	tcpFlags := pkt[iphLen+tcpFlagsOffset]
	var pshSet bool
	// not a candidate if any non-ACK flags (except PSH+ACK) are set
	if tcpFlags != tcpFlagACK {
		if pkt[iphLen+tcpFlagsOffset] != tcpFlagACK|tcpFlagPSH {
			return groResultNoop
		}
		pshSet = true
	}
	gsoSize := uint16(len(pkt) - tcphLen - iphLen)
	// not a candidate if payload len is 0
	if gsoSize < 1 {
		return groResultNoop
	}
	seq := binary.BigEndian.Uint32(pkt[iphLen+4:])
	srcAddrOffset := ipv4SrcAddrOffset
	addrLen := 4
	if isV6 {
		srcAddrOffset = ipv6SrcAddrOffset
		addrLen = 16
	}
	items, existing := table.lookupOrInsert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI)
	if !existing {
		return groResultTableInsert
	}
	for i := len(items) - 1; i >= 0; i-- {
		// In the best case of packets arriving in order iterating in reverse is
		// more efficient if there are multiple items for a given flow. This
		// also enables a natural table.deleteAt() in the
		// coalesceItemInvalidCSum case without the need for index tracking.
		// This algorithm makes a best effort to coalesce in the event of
		// unordered packets, where pkt may land anywhere in items from a
		// sequence number perspective, however once an item is inserted into
		// the table it is never compared across other items later.
		item := items[i]
		can := tcpPacketsCanCoalesce(pkt, uint8(iphLen), uint8(tcphLen), seq, pshSet, gsoSize, item, bufs, offset)
		if can != coalesceUnavailable {
			result := coalesceTCPPackets(can, pkt, pktI, gsoSize, seq, pshSet, &item, bufs, offset, isV6)
			switch result {
			case coalesceSuccess:
				table.updateAt(item, i)
				return groResultCoalesced
			case coalesceItemInvalidCSum:
				// delete the item with an invalid csum
				table.deleteAt(item.key, i)
			case coalescePktInvalidCSum:
				// no point in inserting an item that we can't coalesce
				return groResultNoop
			default:
			}
		}
	}
	// failed to coalesce with any other packets; store the item in the flow
	table.insert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI)
	return groResultTableInsert
}
|
||||||
|
|
||||||
|
// applyTCPCoalesceAccounting updates bufs to account for coalescing based on the
// metadata found in table.
//
// For every item that absorbed at least one merge it rewrites the IP length
// (and IPv4 header checksum), encodes a virtio-net header describing the GSO
// parameters, and seeds the TCP checksum field with the pseudo-header
// checksum so downstream offloading can finish the job. Items that merged
// nothing get a zeroed virtio-net header.
func applyTCPCoalesceAccounting(bufs [][]byte, offset int, table *tcpGROTable) error {
	for _, items := range table.itemsByFlow {
		for _, item := range items {
			if item.numMerged > 0 {
				hdr := virtioNetHdr{
					flags:      unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, // this turns into CHECKSUM_PARTIAL in the skb
					hdrLen:     uint16(item.iphLen + item.tcphLen),
					gsoSize:    item.gsoSize,
					csumStart:  uint16(item.iphLen),
					csumOffset: 16, // TCP checksum field offset within the TCP header
				}
				pkt := bufs[item.bufsIndex][offset:]

				// Recalculate the total len (IPv4) or payload len (IPv6).
				// Recalculate the (IPv4) header checksum.
				if item.key.isV6 {
					hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV6
					binary.BigEndian.PutUint16(pkt[4:], uint16(len(pkt))-uint16(item.iphLen)) // set new IPv6 header payload len
				} else {
					hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV4
					pkt[10], pkt[11] = 0, 0                             // zero the checksum field before recomputing
					binary.BigEndian.PutUint16(pkt[2:], uint16(len(pkt))) // set new total length
					iphCSum := ^checksum(pkt[:item.iphLen], 0)            // compute IPv4 header checksum
					binary.BigEndian.PutUint16(pkt[10:], iphCSum)         // set IPv4 header checksum field
				}
				err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:])
				if err != nil {
					return err
				}

				// Calculate the pseudo header checksum and place it at the TCP
				// checksum offset. Downstream checksum offloading will combine
				// this with computation of the tcp header and payload checksum.
				addrLen := 4
				addrOffset := ipv4SrcAddrOffset
				if item.key.isV6 {
					addrLen = 16
					addrOffset = ipv6SrcAddrOffset
				}
				srcAddrAt := offset + addrOffset
				srcAddr := bufs[item.bufsIndex][srcAddrAt : srcAddrAt+addrLen]
				dstAddr := bufs[item.bufsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2]
				psum := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(len(pkt)-int(item.iphLen)))
				binary.BigEndian.PutUint16(pkt[hdr.csumStart+hdr.csumOffset:], checksum([]byte{}, psum))
			} else {
				// Nothing was merged into this item; it still needs a zeroed
				// virtio-net header in front.
				hdr := virtioNetHdr{}
				err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:])
				if err != nil {
					return err
				}
			}
		}
	}
	return nil
}
|
||||||
|
|
||||||
|
// applyUDPCoalesceAccounting updates bufs to account for coalescing based on the
// metadata found in table.
//
// Mirrors applyTCPCoalesceAccounting for UDP: rewrites IP and UDP length
// fields, encodes a virtio-net header with UDP_L4 GSO parameters, and seeds
// the UDP checksum field with the pseudo-header checksum. Items with no
// merges get a zeroed virtio-net header.
func applyUDPCoalesceAccounting(bufs [][]byte, offset int, table *udpGROTable) error {
	for _, items := range table.itemsByFlow {
		for _, item := range items {
			if item.numMerged > 0 {
				hdr := virtioNetHdr{
					flags:      unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, // this turns into CHECKSUM_PARTIAL in the skb
					hdrLen:     uint16(item.iphLen + udphLen),
					gsoSize:    item.gsoSize,
					csumStart:  uint16(item.iphLen),
					csumOffset: 6, // UDP checksum field offset within the UDP header
				}
				pkt := bufs[item.bufsIndex][offset:]

				// Recalculate the total len (IPv4) or payload len (IPv6).
				// Recalculate the (IPv4) header checksum.
				hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_UDP_L4
				if item.key.isV6 {
					binary.BigEndian.PutUint16(pkt[4:], uint16(len(pkt))-uint16(item.iphLen)) // set new IPv6 header payload len
				} else {
					pkt[10], pkt[11] = 0, 0                               // zero the checksum field before recomputing
					binary.BigEndian.PutUint16(pkt[2:], uint16(len(pkt))) // set new total length
					iphCSum := ^checksum(pkt[:item.iphLen], 0)            // compute IPv4 header checksum
					binary.BigEndian.PutUint16(pkt[10:], iphCSum)         // set IPv4 header checksum field
				}
				err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:])
				if err != nil {
					return err
				}

				// Recalculate the UDP len field value
				binary.BigEndian.PutUint16(pkt[item.iphLen+4:], uint16(len(pkt[item.iphLen:])))

				// Calculate the pseudo header checksum and place it at the UDP
				// checksum offset. Downstream checksum offloading will combine
				// this with computation of the udp header and payload checksum.
				addrLen := 4
				addrOffset := ipv4SrcAddrOffset
				if item.key.isV6 {
					addrLen = 16
					addrOffset = ipv6SrcAddrOffset
				}
				srcAddrAt := offset + addrOffset
				srcAddr := bufs[item.bufsIndex][srcAddrAt : srcAddrAt+addrLen]
				dstAddr := bufs[item.bufsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2]
				psum := pseudoHeaderChecksumNoFold(unix.IPPROTO_UDP, srcAddr, dstAddr, uint16(len(pkt)-int(item.iphLen)))
				binary.BigEndian.PutUint16(pkt[hdr.csumStart+hdr.csumOffset:], checksum([]byte{}, psum))
			} else {
				// Nothing was merged into this item; it still needs a zeroed
				// virtio-net header in front.
				hdr := virtioNetHdr{}
				err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:])
				if err != nil {
					return err
				}
			}
		}
	}
	return nil
}
|
||||||
|
|
||||||
|
// groCandidateType classifies a packet's eligibility for GRO evaluation.
type groCandidateType uint8

const (
	notGROCandidate  groCandidateType = iota // not eligible for GRO
	tcp4GROCandidate                         // IPv4 TCP, eligible for tcpGRO
	tcp6GROCandidate                         // IPv6 TCP, eligible for tcpGRO
	udp4GROCandidate                         // IPv4 UDP, eligible for udpGRO
	udp6GROCandidate                         // IPv6 UDP, eligible for udpGRO
)
|
||||||
|
|
||||||
|
func packetIsGROCandidate(b []byte, canUDPGRO bool) groCandidateType {
|
||||||
|
if len(b) < 28 {
|
||||||
|
return notGROCandidate
|
||||||
|
}
|
||||||
|
if b[0]>>4 == 4 {
|
||||||
|
if b[0]&0x0F != 5 {
|
||||||
|
// IPv4 packets w/IP options do not coalesce
|
||||||
|
return notGROCandidate
|
||||||
|
}
|
||||||
|
if b[9] == unix.IPPROTO_TCP && len(b) >= 40 {
|
||||||
|
return tcp4GROCandidate
|
||||||
|
}
|
||||||
|
if b[9] == unix.IPPROTO_UDP && canUDPGRO {
|
||||||
|
return udp4GROCandidate
|
||||||
|
}
|
||||||
|
} else if b[0]>>4 == 6 {
|
||||||
|
if b[6] == unix.IPPROTO_TCP && len(b) >= 60 {
|
||||||
|
return tcp6GROCandidate
|
||||||
|
}
|
||||||
|
if b[6] == unix.IPPROTO_UDP && len(b) >= 48 && canUDPGRO {
|
||||||
|
return udp6GROCandidate
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return notGROCandidate
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
	// udphLen is the fixed size of a UDP header in bytes.
	udphLen = 8
)
|
||||||
|
|
||||||
|
// udpGRO evaluates the UDP packet at pktI in bufs for coalescing with
// existing packets tracked in table. It returns a groResultNoop when no
// action was taken, groResultTableInsert when the evaluated packet was
// inserted into table, and groResultCoalesced when the evaluated packet was
// coalesced with another packet in table.
func udpGRO(bufs [][]byte, offset int, pktI int, table *udpGROTable, isV6 bool) groResult {
	pkt := bufs[pktI][offset:]
	if len(pkt) > maxUint16 {
		// A valid IPv4 or IPv6 packet will never exceed this.
		return groResultNoop
	}
	// IHL field counts 32-bit words (low nibble of the first byte).
	iphLen := int((pkt[0] & 0x0F) * 4)
	if isV6 {
		iphLen = 40
		// Verify the IPv6 payload-length field matches the bytes we hold.
		ipv6HPayloadLen := int(binary.BigEndian.Uint16(pkt[4:]))
		if ipv6HPayloadLen != len(pkt)-iphLen {
			return groResultNoop
		}
	} else {
		// Verify the IPv4 total-length field matches the bytes we hold.
		totalLen := int(binary.BigEndian.Uint16(pkt[2:]))
		if totalLen != len(pkt) {
			return groResultNoop
		}
	}
	if len(pkt) < iphLen {
		return groResultNoop
	}
	if len(pkt) < iphLen+udphLen {
		return groResultNoop
	}
	if !isV6 {
		if pkt[6]&ipv4FlagMoreFragments != 0 || pkt[6]<<3 != 0 || pkt[7] != 0 {
			// no GRO support for fragmented segments for now
			return groResultNoop
		}
	}
	gsoSize := uint16(len(pkt) - udphLen - iphLen)
	// not a candidate if payload len is 0
	if gsoSize < 1 {
		return groResultNoop
	}
	srcAddrOffset := ipv4SrcAddrOffset
	addrLen := 4
	if isV6 {
		srcAddrOffset = ipv6SrcAddrOffset
		addrLen = 16
	}
	items, existing := table.lookupOrInsert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, pktI)
	if !existing {
		return groResultTableInsert
	}
	// With UDP we only check the last item, otherwise we could reorder packets
	// for a given flow. We must also always insert a new item, or successfully
	// coalesce with an existing item, for the same reason.
	item := items[len(items)-1]
	can := udpPacketsCanCoalesce(pkt, uint8(iphLen), gsoSize, item, bufs, offset)
	var pktCSumKnownInvalid bool
	if can == coalesceAppend {
		result := coalesceUDPPackets(pkt, &item, bufs, offset, isV6)
		switch result {
		case coalesceSuccess:
			table.updateAt(item, len(items)-1)
			return groResultCoalesced
		case coalesceItemInvalidCSum:
			// If the existing item has an invalid csum we take no action. A new
			// item will be stored after it, and the existing item will never be
			// revisited as part of future coalescing candidacy checks.
		case coalescePktInvalidCSum:
			// We must insert a new item, but we also mark it as invalid csum
			// to prevent a repeat checksum validation.
			pktCSumKnownInvalid = true
		default:
		}
	}
	// failed to coalesce with any other packets; store the item in the flow
	table.insert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, pktI, pktCSumKnownInvalid)
	return groResultTableInsert
}
|
||||||
|
|
||||||
|
// handleGRO evaluates bufs for GRO, and writes the indices of the resulting
// packets into toWrite. toWrite, tcpTable, and udpTable should initially be
// empty (but non-nil), and are passed in to save allocs as the caller may reset
// and recycle them across vectors of packets. canUDPGRO indicates if UDP GRO is
// supported.
//
// Each buffer must reserve at least virtioNetHdrLen bytes ahead of offset for
// the virtio-net header that this function (via the accounting helpers)
// writes in front of every surviving packet.
func handleGRO(bufs [][]byte, offset int, tcpTable *tcpGROTable, udpTable *udpGROTable, canUDPGRO bool, toWrite *[]int) error {
	for i := range bufs {
		if offset < virtioNetHdrLen || offset > len(bufs[i])-1 {
			return errors.New("invalid offset")
		}
		var result groResult
		switch packetIsGROCandidate(bufs[i][offset:], canUDPGRO) {
		case tcp4GROCandidate:
			result = tcpGRO(bufs, offset, i, tcpTable, false)
		case tcp6GROCandidate:
			result = tcpGRO(bufs, offset, i, tcpTable, true)
		case udp4GROCandidate:
			result = udpGRO(bufs, offset, i, udpTable, false)
		case udp6GROCandidate:
			result = udpGRO(bufs, offset, i, udpTable, true)
		}
		switch result {
		case groResultNoop:
			// Untouched packets still need a zeroed virtio-net header.
			hdr := virtioNetHdr{}
			err := hdr.encode(bufs[i][offset-virtioNetHdrLen:])
			if err != nil {
				return err
			}
			fallthrough
		case groResultTableInsert:
			// Table-inserted packets are written from their original buffer;
			// coalesced packets are covered by their absorbing item.
			*toWrite = append(*toWrite, i)
		}
	}
	// Fix up lengths, checksums, and virtio-net headers for coalesced items.
	errTCP := applyTCPCoalesceAccounting(bufs, offset, tcpTable)
	errUDP := applyUDPCoalesceAccounting(bufs, offset, udpTable)
	return errors.Join(errTCP, errUDP)
}
|
||||||
|
|
||||||
|
// gsoSplit splits packets from in into outBuffs, writing the size of each
// element into sizes. It returns the number of buffers populated, and/or an
// error.
//
// hdr carries the virtio-net GSO parameters (hdrLen, gsoSize, csumStart,
// csumOffset, gsoType) that drive segmentation. Each segment is written into
// outBuffs[i] starting at outOffset with recomputed IP and transport headers
// and a full transport checksum. Returns ErrTooManySegments if in needs more
// segments than outBuffs provides.
func gsoSplit(in []byte, hdr virtioNetHdr, outBuffs [][]byte, sizes []int, outOffset int, isV6 bool) (int, error) {
	iphLen := int(hdr.csumStart)
	srcAddrOffset := ipv6SrcAddrOffset
	addrLen := 16
	if !isV6 {
		in[10], in[11] = 0, 0 // clear ipv4 header checksum
		srcAddrOffset = ipv4SrcAddrOffset
		addrLen = 4
	}
	transportCsumAt := int(hdr.csumStart + hdr.csumOffset)
	in[transportCsumAt], in[transportCsumAt+1] = 0, 0 // clear tcp/udp checksum
	var firstTCPSeqNum uint32
	var protocol uint8
	if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV4 || hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV6 {
		protocol = unix.IPPROTO_TCP
		firstTCPSeqNum = binary.BigEndian.Uint32(in[hdr.csumStart+4:])
	} else {
		protocol = unix.IPPROTO_UDP
	}
	nextSegmentDataAt := int(hdr.hdrLen)
	i := 0
	for ; nextSegmentDataAt < len(in); i++ {
		if i == len(outBuffs) {
			return i - 1, ErrTooManySegments
		}
		// Each segment carries up to gsoSize payload bytes; the last may be
		// short.
		nextSegmentEnd := nextSegmentDataAt + int(hdr.gsoSize)
		if nextSegmentEnd > len(in) {
			nextSegmentEnd = len(in)
		}
		segmentDataLen := nextSegmentEnd - nextSegmentDataAt
		totalLen := int(hdr.hdrLen) + segmentDataLen
		sizes[i] = totalLen
		out := outBuffs[i][outOffset:]

		copy(out, in[:iphLen])
		if !isV6 {
			// For IPv4 we are responsible for incrementing the ID field,
			// updating the total len field, and recalculating the header
			// checksum.
			if i > 0 {
				id := binary.BigEndian.Uint16(out[4:])
				id += uint16(i)
				binary.BigEndian.PutUint16(out[4:], id)
			}
			binary.BigEndian.PutUint16(out[2:], uint16(totalLen))
			ipv4CSum := ^checksum(out[:iphLen], 0)
			binary.BigEndian.PutUint16(out[10:], ipv4CSum)
		} else {
			// For IPv6 we are responsible for updating the payload length field.
			binary.BigEndian.PutUint16(out[4:], uint16(totalLen-iphLen))
		}

		// copy transport header
		copy(out[hdr.csumStart:hdr.hdrLen], in[hdr.csumStart:hdr.hdrLen])

		if protocol == unix.IPPROTO_TCP {
			// set TCP seq and adjust TCP flags
			tcpSeq := firstTCPSeqNum + uint32(hdr.gsoSize*uint16(i))
			binary.BigEndian.PutUint32(out[hdr.csumStart+4:], tcpSeq)
			if nextSegmentEnd != len(in) {
				// FIN and PSH should only be set on last segment
				clearFlags := tcpFlagFIN | tcpFlagPSH
				out[hdr.csumStart+tcpFlagsOffset] &^= clearFlags
			}
		} else {
			// set UDP header len
			binary.BigEndian.PutUint16(out[hdr.csumStart+4:], uint16(segmentDataLen)+(hdr.hdrLen-hdr.csumStart))
		}

		// payload
		copy(out[hdr.hdrLen:], in[nextSegmentDataAt:nextSegmentEnd])

		// transport checksum
		transportHeaderLen := int(hdr.hdrLen - hdr.csumStart)
		lenForPseudo := uint16(transportHeaderLen + segmentDataLen)
		transportCSumNoFold := pseudoHeaderChecksumNoFold(protocol, in[srcAddrOffset:srcAddrOffset+addrLen], in[srcAddrOffset+addrLen:srcAddrOffset+addrLen*2], lenForPseudo)
		transportCSum := ^checksum(out[hdr.csumStart:totalLen], transportCSumNoFold)
		binary.BigEndian.PutUint16(out[hdr.csumStart+hdr.csumOffset:], transportCSum)

		nextSegmentDataAt += int(hdr.gsoSize)
	}
	return i, nil
}
|
||||||
|
|
||||||
|
func gsoNoneChecksum(in []byte, cSumStart, cSumOffset uint16) error {
|
||||||
|
cSumAt := cSumStart + cSumOffset
|
||||||
|
// The initial value at the checksum offset should be summed with the
|
||||||
|
// checksum we compute. This is typically the pseudo-header checksum.
|
||||||
|
initial := binary.BigEndian.Uint16(in[cSumAt:])
|
||||||
|
in[cSumAt], in[cSumAt+1] = 0, 0
|
||||||
|
binary.BigEndian.PutUint16(in[cSumAt:], ^checksum(in[cSumStart:], uint64(initial)))
|
||||||
|
return nil
|
||||||
|
}
|
752
tun/offload_linux_test.go
Normal file
752
tun/offload_linux_test.go
Normal file
|
@ -0,0 +1,752 @@
|
||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package tun
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/amnezia-vpn/amneziawg-go/conn"
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
	// offset is where packet data begins in every test buffer, leaving room
	// for a virtio-net header in front.
	offset = virtioNetHdrLen
)
|
||||||
|
|
||||||
|
// Test endpoints. The A/B pairs form one flow per address family; C differs
// only in address so packets involving it land in a distinct flow.
var (
	ip4PortA = netip.MustParseAddrPort("192.0.2.1:1")
	ip4PortB = netip.MustParseAddrPort("192.0.2.2:1")
	ip4PortC = netip.MustParseAddrPort("192.0.2.3:1")
	ip6PortA = netip.MustParseAddrPort("[2001:db8::1]:1")
	ip6PortB = netip.MustParseAddrPort("[2001:db8::2]:1")
	ip6PortC = netip.MustParseAddrPort("[2001:db8::3]:1")
)
|
||||||
|
|
||||||
|
// udp4PacketMutateIPFields builds a checksummed IPv4/UDP packet carrying
// payloadLen zero bytes of payload, preceded by virtioNetHdrLen bytes of
// headroom. If ipFn is non-nil it may mutate the IPv4 fields before encoding
// (checksums are computed after the mutation).
func udp4PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, payloadLen int, ipFn func(*header.IPv4Fields)) []byte {
	totalLen := 28 + payloadLen // 20-byte IPv4 header + 8-byte UDP header + payload
	b := make([]byte, offset+int(totalLen), 65535)
	ipv4H := header.IPv4(b[offset:])
	srcAs4 := srcIPPort.Addr().As4()
	dstAs4 := dstIPPort.Addr().As4()
	ipFields := &header.IPv4Fields{
		SrcAddr:     tcpip.AddrFromSlice(srcAs4[:]),
		DstAddr:     tcpip.AddrFromSlice(dstAs4[:]),
		Protocol:    unix.IPPROTO_UDP,
		TTL:         64,
		TotalLength: uint16(totalLen),
	}
	if ipFn != nil {
		ipFn(ipFields)
	}
	ipv4H.Encode(ipFields)
	udpH := header.UDP(b[offset+20:])
	udpH.Encode(&header.UDPFields{
		SrcPort: srcIPPort.Port(),
		DstPort: dstIPPort.Port(),
		Length:  uint16(payloadLen + udphLen),
	})
	// Valid IPv4 header checksum and UDP checksum (over the pseudo header).
	ipv4H.SetChecksum(^ipv4H.CalculateChecksum())
	pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_UDP, ipv4H.SourceAddress(), ipv4H.DestinationAddress(), uint16(udphLen+payloadLen))
	udpH.SetChecksum(^udpH.CalculateChecksum(pseudoCsum))
	return b
}
|
||||||
|
|
||||||
|
func udp6Packet(srcIPPort, dstIPPort netip.AddrPort, payloadLen int) []byte {
|
||||||
|
return udp6PacketMutateIPFields(srcIPPort, dstIPPort, payloadLen, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// udp6PacketMutateIPFields builds a checksummed IPv6/UDP packet carrying
// payloadLen zero bytes of payload, preceded by virtioNetHdrLen bytes of
// headroom. If ipFn is non-nil it may mutate the IPv6 fields before encoding
// (the UDP checksum is computed after the mutation).
func udp6PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, payloadLen int, ipFn func(*header.IPv6Fields)) []byte {
	totalLen := 48 + payloadLen // 40-byte IPv6 header + 8-byte UDP header + payload
	b := make([]byte, offset+int(totalLen), 65535)
	ipv6H := header.IPv6(b[offset:])
	srcAs16 := srcIPPort.Addr().As16()
	dstAs16 := dstIPPort.Addr().As16()
	ipFields := &header.IPv6Fields{
		SrcAddr:           tcpip.AddrFromSlice(srcAs16[:]),
		DstAddr:           tcpip.AddrFromSlice(dstAs16[:]),
		TransportProtocol: unix.IPPROTO_UDP,
		HopLimit:          64,
		PayloadLength:     uint16(payloadLen + udphLen),
	}
	if ipFn != nil {
		ipFn(ipFields)
	}
	ipv6H.Encode(ipFields)
	udpH := header.UDP(b[offset+40:])
	udpH.Encode(&header.UDPFields{
		SrcPort: srcIPPort.Port(),
		DstPort: dstIPPort.Port(),
		Length:  uint16(payloadLen + udphLen),
	})
	// Valid UDP checksum over the pseudo header (IPv6 has no header checksum).
	pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_UDP, ipv6H.SourceAddress(), ipv6H.DestinationAddress(), uint16(udphLen+payloadLen))
	udpH.SetChecksum(^udpH.CalculateChecksum(pseudoCsum))
	return b
}
|
||||||
|
|
||||||
|
func udp4Packet(srcIPPort, dstIPPort netip.AddrPort, payloadLen int) []byte {
|
||||||
|
return udp4PacketMutateIPFields(srcIPPort, dstIPPort, payloadLen, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// tcp4PacketMutateIPFields builds a checksummed IPv4/TCP packet with the given
// flags, seq number, and segmentSize zero bytes of payload, preceded by
// virtioNetHdrLen bytes of headroom. If ipFn is non-nil it may mutate the
// IPv4 fields before encoding (checksums are computed after the mutation).
func tcp4PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32, ipFn func(*header.IPv4Fields)) []byte {
	totalLen := 40 + segmentSize // 20-byte IPv4 header + 20-byte TCP header + payload
	b := make([]byte, offset+int(totalLen), 65535)
	ipv4H := header.IPv4(b[offset:])
	srcAs4 := srcIPPort.Addr().As4()
	dstAs4 := dstIPPort.Addr().As4()
	ipFields := &header.IPv4Fields{
		SrcAddr:     tcpip.AddrFromSlice(srcAs4[:]),
		DstAddr:     tcpip.AddrFromSlice(dstAs4[:]),
		Protocol:    unix.IPPROTO_TCP,
		TTL:         64,
		TotalLength: uint16(totalLen),
	}
	if ipFn != nil {
		ipFn(ipFields)
	}
	ipv4H.Encode(ipFields)
	tcpH := header.TCP(b[offset+20:])
	tcpH.Encode(&header.TCPFields{
		SrcPort:    srcIPPort.Port(),
		DstPort:    dstIPPort.Port(),
		SeqNum:     seq,
		AckNum:     1,
		DataOffset: 20,
		Flags:      flags,
		WindowSize: 3000,
	})
	// Valid IPv4 header checksum and TCP checksum (over the pseudo header).
	ipv4H.SetChecksum(^ipv4H.CalculateChecksum())
	pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_TCP, ipv4H.SourceAddress(), ipv4H.DestinationAddress(), uint16(20+segmentSize))
	tcpH.SetChecksum(^tcpH.CalculateChecksum(pseudoCsum))
	return b
}
|
||||||
|
|
||||||
|
func tcp4Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32) []byte {
|
||||||
|
return tcp4PacketMutateIPFields(srcIPPort, dstIPPort, flags, segmentSize, seq, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// tcp6PacketMutateIPFields builds a checksummed IPv6/TCP packet with the given
// flags, seq number, and segmentSize zero bytes of payload, preceded by
// virtioNetHdrLen bytes of headroom. If ipFn is non-nil it may mutate the
// IPv6 fields before encoding (the TCP checksum is computed after the
// mutation).
func tcp6PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32, ipFn func(*header.IPv6Fields)) []byte {
	totalLen := 60 + segmentSize // 40-byte IPv6 header + 20-byte TCP header + payload
	b := make([]byte, offset+int(totalLen), 65535)
	ipv6H := header.IPv6(b[offset:])
	srcAs16 := srcIPPort.Addr().As16()
	dstAs16 := dstIPPort.Addr().As16()
	ipFields := &header.IPv6Fields{
		SrcAddr:           tcpip.AddrFromSlice(srcAs16[:]),
		DstAddr:           tcpip.AddrFromSlice(dstAs16[:]),
		TransportProtocol: unix.IPPROTO_TCP,
		HopLimit:          64,
		PayloadLength:     uint16(segmentSize + 20),
	}
	if ipFn != nil {
		ipFn(ipFields)
	}
	ipv6H.Encode(ipFields)
	tcpH := header.TCP(b[offset+40:])
	tcpH.Encode(&header.TCPFields{
		SrcPort:    srcIPPort.Port(),
		DstPort:    dstIPPort.Port(),
		SeqNum:     seq,
		AckNum:     1,
		DataOffset: 20,
		Flags:      flags,
		WindowSize: 3000,
	})
	// Valid TCP checksum over the pseudo header (IPv6 has no header checksum).
	pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_TCP, ipv6H.SourceAddress(), ipv6H.DestinationAddress(), uint16(20+segmentSize))
	tcpH.SetChecksum(^tcpH.CalculateChecksum(pseudoCsum))
	return b
}
|
||||||
|
|
||||||
|
func tcp6Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32) []byte {
|
||||||
|
return tcp6PacketMutateIPFields(srcIPPort, dstIPPort, flags, segmentSize, seq, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test_handleVirtioRead verifies that handleVirtioRead splits a GSO packet
// (built with a 200-byte payload and gsoSize 100) into two segments of the
// expected on-wire lengths for each supported gsoType.
func Test_handleVirtioRead(t *testing.T) {
	tests := []struct {
		name     string
		hdr      virtioNetHdr // GSO parameters encoded in front of pktIn
		pktIn    []byte
		wantLens []int // expected per-segment sizes written by handleVirtioRead
		wantErr  bool
	}{
		{
			"tcp4",
			virtioNetHdr{
				flags:      unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
				gsoType:    unix.VIRTIO_NET_HDR_GSO_TCPV4,
				gsoSize:    100,
				hdrLen:     40,
				csumStart:  20,
				csumOffset: 16,
			},
			tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck|header.TCPFlagPsh, 200, 1),
			[]int{140, 140},
			false,
		},
		{
			"tcp6",
			virtioNetHdr{
				flags:      unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
				gsoType:    unix.VIRTIO_NET_HDR_GSO_TCPV6,
				gsoSize:    100,
				hdrLen:     60,
				csumStart:  40,
				csumOffset: 16,
			},
			tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck|header.TCPFlagPsh, 200, 1),
			[]int{160, 160},
			false,
		},
		{
			"udp4",
			virtioNetHdr{
				flags:      unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
				gsoType:    unix.VIRTIO_NET_HDR_GSO_UDP_L4,
				gsoSize:    100,
				hdrLen:     28,
				csumStart:  20,
				csumOffset: 6,
			},
			udp4Packet(ip4PortA, ip4PortB, 200),
			[]int{128, 128},
			false,
		},
		{
			"udp6",
			virtioNetHdr{
				flags:      unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
				gsoType:    unix.VIRTIO_NET_HDR_GSO_UDP_L4,
				gsoSize:    100,
				hdrLen:     48,
				csumStart:  40,
				csumOffset: 6,
			},
			udp6Packet(ip6PortA, ip6PortB, 200),
			[]int{148, 148},
			false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			out := make([][]byte, conn.IdealBatchSize)
			sizes := make([]int, conn.IdealBatchSize)
			for i := range out {
				out[i] = make([]byte, 65535)
			}
			// Write the virtio-net header into the packet's headroom.
			tt.hdr.encode(tt.pktIn)
			n, err := handleVirtioRead(tt.pktIn, out, sizes, offset)
			if err != nil {
				if tt.wantErr {
					return
				}
				t.Fatalf("got err: %v", err)
			}
			if n != len(tt.wantLens) {
				t.Fatalf("got %d packets, wanted %d", n, len(tt.wantLens))
			}
			for i := range tt.wantLens {
				if tt.wantLens[i] != sizes[i] {
					t.Fatalf("wantLens[%d]: %d != outSizes: %d", i, tt.wantLens[i], sizes[i])
				}
			}
		})
	}
}
|
||||||
|
|
||||||
|
func flipTCP4Checksum(b []byte) []byte {
|
||||||
|
at := virtioNetHdrLen + 20 + 16 // 20 byte ipv4 header; tcp csum offset is 16
|
||||||
|
b[at] ^= 0xFF
|
||||||
|
b[at+1] ^= 0xFF
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
func flipUDP4Checksum(b []byte) []byte {
|
||||||
|
at := virtioNetHdrLen + 20 + 6 // 20 byte ipv4 header; udp csum offset is 6
|
||||||
|
b[at] ^= 0xFF
|
||||||
|
b[at+1] ^= 0xFF
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
func Fuzz_handleGRO(f *testing.F) {
|
||||||
|
pkt0 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)
|
||||||
|
pkt1 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101)
|
||||||
|
pkt2 := tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201)
|
||||||
|
pkt3 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1)
|
||||||
|
pkt4 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101)
|
||||||
|
pkt5 := tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201)
|
||||||
|
pkt6 := udp4Packet(ip4PortA, ip4PortB, 100)
|
||||||
|
pkt7 := udp4Packet(ip4PortA, ip4PortB, 100)
|
||||||
|
pkt8 := udp4Packet(ip4PortA, ip4PortC, 100)
|
||||||
|
pkt9 := udp6Packet(ip6PortA, ip6PortB, 100)
|
||||||
|
pkt10 := udp6Packet(ip6PortA, ip6PortB, 100)
|
||||||
|
pkt11 := udp6Packet(ip6PortA, ip6PortC, 100)
|
||||||
|
f.Add(pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, pkt6, pkt7, pkt8, pkt9, pkt10, pkt11, true, offset)
|
||||||
|
f.Fuzz(func(t *testing.T, pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, pkt6, pkt7, pkt8, pkt9, pkt10, pkt11 []byte, canUDPGRO bool, offset int) {
|
||||||
|
pkts := [][]byte{pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, pkt6, pkt7, pkt8, pkt9, pkt10, pkt11}
|
||||||
|
toWrite := make([]int, 0, len(pkts))
|
||||||
|
handleGRO(pkts, offset, newTCPGROTable(), newUDPGROTable(), canUDPGRO, &toWrite)
|
||||||
|
if len(toWrite) > len(pkts) {
|
||||||
|
t.Errorf("len(toWrite): %d > len(pkts): %d", len(toWrite), len(pkts))
|
||||||
|
}
|
||||||
|
seenWriteI := make(map[int]bool)
|
||||||
|
for _, writeI := range toWrite {
|
||||||
|
if writeI < 0 || writeI > len(pkts)-1 {
|
||||||
|
t.Errorf("toWrite value (%d) outside bounds of len(pkts): %d", writeI, len(pkts))
|
||||||
|
}
|
||||||
|
if seenWriteI[writeI] {
|
||||||
|
t.Errorf("duplicate toWrite value: %d", writeI)
|
||||||
|
}
|
||||||
|
seenWriteI[writeI] = true
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_handleGRO(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
pktsIn [][]byte
|
||||||
|
canUDPGRO bool
|
||||||
|
wantToWrite []int
|
||||||
|
wantLens []int
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"multiple protocols and flows",
|
||||||
|
[][]byte{
|
||||||
|
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // tcp4 flow 1
|
||||||
|
udp4Packet(ip4PortA, ip4PortB, 100), // udp4 flow 1
|
||||||
|
udp4Packet(ip4PortA, ip4PortC, 100), // udp4 flow 2
|
||||||
|
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // tcp4 flow 1
|
||||||
|
tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201), // tcp4 flow 2
|
||||||
|
tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), // tcp6 flow 1
|
||||||
|
tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101), // tcp6 flow 1
|
||||||
|
tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201), // tcp6 flow 2
|
||||||
|
udp4Packet(ip4PortA, ip4PortB, 100), // udp4 flow 1
|
||||||
|
udp6Packet(ip6PortA, ip6PortB, 100), // udp6 flow 1
|
||||||
|
udp6Packet(ip6PortA, ip6PortB, 100), // udp6 flow 1
|
||||||
|
},
|
||||||
|
true,
|
||||||
|
[]int{0, 1, 2, 4, 5, 7, 9},
|
||||||
|
[]int{240, 228, 128, 140, 260, 160, 248},
|
||||||
|
false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"multiple protocols and flows no UDP GRO",
|
||||||
|
[][]byte{
|
||||||
|
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // tcp4 flow 1
|
||||||
|
udp4Packet(ip4PortA, ip4PortB, 100), // udp4 flow 1
|
||||||
|
udp4Packet(ip4PortA, ip4PortC, 100), // udp4 flow 2
|
||||||
|
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // tcp4 flow 1
|
||||||
|
tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201), // tcp4 flow 2
|
||||||
|
tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), // tcp6 flow 1
|
||||||
|
tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101), // tcp6 flow 1
|
||||||
|
tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201), // tcp6 flow 2
|
||||||
|
udp4Packet(ip4PortA, ip4PortB, 100), // udp4 flow 1
|
||||||
|
udp6Packet(ip6PortA, ip6PortB, 100), // udp6 flow 1
|
||||||
|
udp6Packet(ip6PortA, ip6PortB, 100), // udp6 flow 1
|
||||||
|
},
|
||||||
|
false,
|
||||||
|
[]int{0, 1, 2, 4, 5, 7, 8, 9, 10},
|
||||||
|
[]int{240, 128, 128, 140, 260, 160, 128, 148, 148},
|
||||||
|
false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"PSH interleaved",
|
||||||
|
[][]byte{
|
||||||
|
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // v4 flow 1
|
||||||
|
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck|header.TCPFlagPsh, 100, 101), // v4 flow 1
|
||||||
|
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1
|
||||||
|
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 301), // v4 flow 1
|
||||||
|
tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), // v6 flow 1
|
||||||
|
tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck|header.TCPFlagPsh, 100, 101), // v6 flow 1
|
||||||
|
tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 201), // v6 flow 1
|
||||||
|
tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 301), // v6 flow 1
|
||||||
|
},
|
||||||
|
true,
|
||||||
|
[]int{0, 2, 4, 6},
|
||||||
|
[]int{240, 240, 260, 260},
|
||||||
|
false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"coalesceItemInvalidCSum",
|
||||||
|
[][]byte{
|
||||||
|
flipTCP4Checksum(tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)), // v4 flow 1 seq 1 len 100
|
||||||
|
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1 seq 101 len 100
|
||||||
|
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 seq 201 len 100
|
||||||
|
flipUDP4Checksum(udp4Packet(ip4PortA, ip4PortB, 100)),
|
||||||
|
udp4Packet(ip4PortA, ip4PortB, 100),
|
||||||
|
udp4Packet(ip4PortA, ip4PortB, 100),
|
||||||
|
},
|
||||||
|
true,
|
||||||
|
[]int{0, 1, 3, 4},
|
||||||
|
[]int{140, 240, 128, 228},
|
||||||
|
false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"out of order",
|
||||||
|
[][]byte{
|
||||||
|
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1 seq 101 len 100
|
||||||
|
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // v4 flow 1 seq 1 len 100
|
||||||
|
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 seq 201 len 100
|
||||||
|
},
|
||||||
|
true,
|
||||||
|
[]int{0},
|
||||||
|
[]int{340},
|
||||||
|
false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"unequal TTL",
|
||||||
|
[][]byte{
|
||||||
|
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),
|
||||||
|
tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) {
|
||||||
|
fields.TTL++
|
||||||
|
}),
|
||||||
|
udp4Packet(ip4PortA, ip4PortB, 100),
|
||||||
|
udp4PacketMutateIPFields(ip4PortA, ip4PortB, 100, func(fields *header.IPv4Fields) {
|
||||||
|
fields.TTL++
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
true,
|
||||||
|
[]int{0, 1, 2, 3},
|
||||||
|
[]int{140, 140, 128, 128},
|
||||||
|
false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"unequal ToS",
|
||||||
|
[][]byte{
|
||||||
|
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),
|
||||||
|
tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) {
|
||||||
|
fields.TOS++
|
||||||
|
}),
|
||||||
|
udp4Packet(ip4PortA, ip4PortB, 100),
|
||||||
|
udp4PacketMutateIPFields(ip4PortA, ip4PortB, 100, func(fields *header.IPv4Fields) {
|
||||||
|
fields.TOS++
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
true,
|
||||||
|
[]int{0, 1, 2, 3},
|
||||||
|
[]int{140, 140, 128, 128},
|
||||||
|
false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"unequal flags more fragments set",
|
||||||
|
[][]byte{
|
||||||
|
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),
|
||||||
|
tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) {
|
||||||
|
fields.Flags = 1
|
||||||
|
}),
|
||||||
|
udp4Packet(ip4PortA, ip4PortB, 100),
|
||||||
|
udp4PacketMutateIPFields(ip4PortA, ip4PortB, 100, func(fields *header.IPv4Fields) {
|
||||||
|
fields.Flags = 1
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
true,
|
||||||
|
[]int{0, 1, 2, 3},
|
||||||
|
[]int{140, 140, 128, 128},
|
||||||
|
false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"unequal flags DF set",
|
||||||
|
[][]byte{
|
||||||
|
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),
|
||||||
|
tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) {
|
||||||
|
fields.Flags = 2
|
||||||
|
}),
|
||||||
|
udp4Packet(ip4PortA, ip4PortB, 100),
|
||||||
|
udp4PacketMutateIPFields(ip4PortA, ip4PortB, 100, func(fields *header.IPv4Fields) {
|
||||||
|
fields.Flags = 2
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
true,
|
||||||
|
[]int{0, 1, 2, 3},
|
||||||
|
[]int{140, 140, 128, 128},
|
||||||
|
false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"ipv6 unequal hop limit",
|
||||||
|
[][]byte{
|
||||||
|
tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1),
|
||||||
|
tcp6PacketMutateIPFields(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv6Fields) {
|
||||||
|
fields.HopLimit++
|
||||||
|
}),
|
||||||
|
udp6Packet(ip6PortA, ip6PortB, 100),
|
||||||
|
udp6PacketMutateIPFields(ip6PortA, ip6PortB, 100, func(fields *header.IPv6Fields) {
|
||||||
|
fields.HopLimit++
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
true,
|
||||||
|
[]int{0, 1, 2, 3},
|
||||||
|
[]int{160, 160, 148, 148},
|
||||||
|
false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"ipv6 unequal traffic class",
|
||||||
|
[][]byte{
|
||||||
|
tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1),
|
||||||
|
tcp6PacketMutateIPFields(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv6Fields) {
|
||||||
|
fields.TrafficClass++
|
||||||
|
}),
|
||||||
|
udp6Packet(ip6PortA, ip6PortB, 100),
|
||||||
|
udp6PacketMutateIPFields(ip6PortA, ip6PortB, 100, func(fields *header.IPv6Fields) {
|
||||||
|
fields.TrafficClass++
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
true,
|
||||||
|
[]int{0, 1, 2, 3},
|
||||||
|
[]int{160, 160, 148, 148},
|
||||||
|
false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
toWrite := make([]int, 0, len(tt.pktsIn))
|
||||||
|
err := handleGRO(tt.pktsIn, offset, newTCPGROTable(), newUDPGROTable(), tt.canUDPGRO, &toWrite)
|
||||||
|
if err != nil {
|
||||||
|
if tt.wantErr {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.Fatalf("got err: %v", err)
|
||||||
|
}
|
||||||
|
if len(toWrite) != len(tt.wantToWrite) {
|
||||||
|
t.Fatalf("got %d packets, wanted %d", len(toWrite), len(tt.wantToWrite))
|
||||||
|
}
|
||||||
|
for i, pktI := range tt.wantToWrite {
|
||||||
|
if tt.wantToWrite[i] != toWrite[i] {
|
||||||
|
t.Fatalf("wantToWrite[%d]: %d != toWrite: %d", i, tt.wantToWrite[i], toWrite[i])
|
||||||
|
}
|
||||||
|
if tt.wantLens[i] != len(tt.pktsIn[pktI][offset:]) {
|
||||||
|
t.Errorf("wanted len %d packet at %d, got: %d", tt.wantLens[i], i, len(tt.pktsIn[pktI][offset:]))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_packetIsGROCandidate(t *testing.T) {
|
||||||
|
tcp4 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)[virtioNetHdrLen:]
|
||||||
|
tcp4TooShort := tcp4[:39]
|
||||||
|
ip4InvalidHeaderLen := make([]byte, len(tcp4))
|
||||||
|
copy(ip4InvalidHeaderLen, tcp4)
|
||||||
|
ip4InvalidHeaderLen[0] = 0x46
|
||||||
|
ip4InvalidProtocol := make([]byte, len(tcp4))
|
||||||
|
copy(ip4InvalidProtocol, tcp4)
|
||||||
|
ip4InvalidProtocol[9] = unix.IPPROTO_GRE
|
||||||
|
|
||||||
|
tcp6 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1)[virtioNetHdrLen:]
|
||||||
|
tcp6TooShort := tcp6[:59]
|
||||||
|
ip6InvalidProtocol := make([]byte, len(tcp6))
|
||||||
|
copy(ip6InvalidProtocol, tcp6)
|
||||||
|
ip6InvalidProtocol[6] = unix.IPPROTO_GRE
|
||||||
|
|
||||||
|
udp4 := udp4Packet(ip4PortA, ip4PortB, 100)[virtioNetHdrLen:]
|
||||||
|
udp4TooShort := udp4[:27]
|
||||||
|
|
||||||
|
udp6 := udp6Packet(ip6PortA, ip6PortB, 100)[virtioNetHdrLen:]
|
||||||
|
udp6TooShort := udp6[:47]
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
b []byte
|
||||||
|
canUDPGRO bool
|
||||||
|
want groCandidateType
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"tcp4",
|
||||||
|
tcp4,
|
||||||
|
true,
|
||||||
|
tcp4GROCandidate,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"tcp6",
|
||||||
|
tcp6,
|
||||||
|
true,
|
||||||
|
tcp6GROCandidate,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"udp4",
|
||||||
|
udp4,
|
||||||
|
true,
|
||||||
|
udp4GROCandidate,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"udp4 no support",
|
||||||
|
udp4,
|
||||||
|
false,
|
||||||
|
notGROCandidate,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"udp6",
|
||||||
|
udp6,
|
||||||
|
true,
|
||||||
|
udp6GROCandidate,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"udp6 no support",
|
||||||
|
udp6,
|
||||||
|
false,
|
||||||
|
notGROCandidate,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"udp4 too short",
|
||||||
|
udp4TooShort,
|
||||||
|
true,
|
||||||
|
notGROCandidate,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"udp6 too short",
|
||||||
|
udp6TooShort,
|
||||||
|
true,
|
||||||
|
notGROCandidate,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"tcp4 too short",
|
||||||
|
tcp4TooShort,
|
||||||
|
true,
|
||||||
|
notGROCandidate,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"tcp6 too short",
|
||||||
|
tcp6TooShort,
|
||||||
|
true,
|
||||||
|
notGROCandidate,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"invalid IP version",
|
||||||
|
[]byte{0x00},
|
||||||
|
true,
|
||||||
|
notGROCandidate,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"invalid IP header len",
|
||||||
|
ip4InvalidHeaderLen,
|
||||||
|
true,
|
||||||
|
notGROCandidate,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"ip4 invalid protocol",
|
||||||
|
ip4InvalidProtocol,
|
||||||
|
true,
|
||||||
|
notGROCandidate,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"ip6 invalid protocol",
|
||||||
|
ip6InvalidProtocol,
|
||||||
|
true,
|
||||||
|
notGROCandidate,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := packetIsGROCandidate(tt.b, tt.canUDPGRO); got != tt.want {
|
||||||
|
t.Errorf("packetIsGROCandidate() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_udpPacketsCanCoalesce(t *testing.T) {
|
||||||
|
udp4a := udp4Packet(ip4PortA, ip4PortB, 100)
|
||||||
|
udp4b := udp4Packet(ip4PortA, ip4PortB, 100)
|
||||||
|
udp4c := udp4Packet(ip4PortA, ip4PortB, 110)
|
||||||
|
|
||||||
|
type args struct {
|
||||||
|
pkt []byte
|
||||||
|
iphLen uint8
|
||||||
|
gsoSize uint16
|
||||||
|
item udpGROItem
|
||||||
|
bufs [][]byte
|
||||||
|
bufsOffset int
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
want canCoalesce
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"coalesceAppend equal gso",
|
||||||
|
args{
|
||||||
|
pkt: udp4a[offset:],
|
||||||
|
iphLen: 20,
|
||||||
|
gsoSize: 100,
|
||||||
|
item: udpGROItem{
|
||||||
|
gsoSize: 100,
|
||||||
|
iphLen: 20,
|
||||||
|
},
|
||||||
|
bufs: [][]byte{
|
||||||
|
udp4a,
|
||||||
|
udp4b,
|
||||||
|
},
|
||||||
|
bufsOffset: offset,
|
||||||
|
},
|
||||||
|
coalesceAppend,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"coalesceAppend smaller gso",
|
||||||
|
args{
|
||||||
|
pkt: udp4a[offset : len(udp4a)-90],
|
||||||
|
iphLen: 20,
|
||||||
|
gsoSize: 10,
|
||||||
|
item: udpGROItem{
|
||||||
|
gsoSize: 100,
|
||||||
|
iphLen: 20,
|
||||||
|
},
|
||||||
|
bufs: [][]byte{
|
||||||
|
udp4a,
|
||||||
|
udp4b,
|
||||||
|
},
|
||||||
|
bufsOffset: offset,
|
||||||
|
},
|
||||||
|
coalesceAppend,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"coalesceUnavailable smaller gso previously appended",
|
||||||
|
args{
|
||||||
|
pkt: udp4a[offset:],
|
||||||
|
iphLen: 20,
|
||||||
|
gsoSize: 100,
|
||||||
|
item: udpGROItem{
|
||||||
|
gsoSize: 100,
|
||||||
|
iphLen: 20,
|
||||||
|
},
|
||||||
|
bufs: [][]byte{
|
||||||
|
udp4c,
|
||||||
|
udp4b,
|
||||||
|
},
|
||||||
|
bufsOffset: offset,
|
||||||
|
},
|
||||||
|
coalesceUnavailable,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"coalesceUnavailable larger following smaller",
|
||||||
|
args{
|
||||||
|
pkt: udp4c[offset:],
|
||||||
|
iphLen: 20,
|
||||||
|
gsoSize: 110,
|
||||||
|
item: udpGROItem{
|
||||||
|
gsoSize: 100,
|
||||||
|
iphLen: 20,
|
||||||
|
},
|
||||||
|
bufs: [][]byte{
|
||||||
|
udp4a,
|
||||||
|
udp4c,
|
||||||
|
},
|
||||||
|
bufsOffset: offset,
|
||||||
|
},
|
||||||
|
coalesceUnavailable,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := udpPacketsCanCoalesce(tt.args.pkt, tt.args.iphLen, tt.args.gsoSize, tt.args.item, tt.args.bufs, tt.args.bufsOffset); got != tt.want {
|
||||||
|
t.Errorf("udpPacketsCanCoalesce() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,627 +0,0 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package tun
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"encoding/binary"
|
|
||||||
"errors"
|
|
||||||
"io"
|
|
||||||
"unsafe"
|
|
||||||
|
|
||||||
"github.com/amnezia-vpn/amnezia-wg/conn"
|
|
||||||
"golang.org/x/sys/unix"
|
|
||||||
)
|
|
||||||
|
|
||||||
const tcpFlagsOffset = 13
|
|
||||||
|
|
||||||
const (
|
|
||||||
tcpFlagFIN uint8 = 0x01
|
|
||||||
tcpFlagPSH uint8 = 0x08
|
|
||||||
tcpFlagACK uint8 = 0x10
|
|
||||||
)
|
|
||||||
|
|
||||||
// virtioNetHdr is defined in the kernel in include/uapi/linux/virtio_net.h. The
|
|
||||||
// kernel symbol is virtio_net_hdr.
|
|
||||||
type virtioNetHdr struct {
|
|
||||||
flags uint8
|
|
||||||
gsoType uint8
|
|
||||||
hdrLen uint16
|
|
||||||
gsoSize uint16
|
|
||||||
csumStart uint16
|
|
||||||
csumOffset uint16
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *virtioNetHdr) decode(b []byte) error {
|
|
||||||
if len(b) < virtioNetHdrLen {
|
|
||||||
return io.ErrShortBuffer
|
|
||||||
}
|
|
||||||
copy(unsafe.Slice((*byte)(unsafe.Pointer(v)), virtioNetHdrLen), b[:virtioNetHdrLen])
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *virtioNetHdr) encode(b []byte) error {
|
|
||||||
if len(b) < virtioNetHdrLen {
|
|
||||||
return io.ErrShortBuffer
|
|
||||||
}
|
|
||||||
copy(b[:virtioNetHdrLen], unsafe.Slice((*byte)(unsafe.Pointer(v)), virtioNetHdrLen))
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
const (
|
|
||||||
// virtioNetHdrLen is the length in bytes of virtioNetHdr. This matches the
|
|
||||||
// shape of the C ABI for its kernel counterpart -- sizeof(virtio_net_hdr).
|
|
||||||
virtioNetHdrLen = int(unsafe.Sizeof(virtioNetHdr{}))
|
|
||||||
)
|
|
||||||
|
|
||||||
// flowKey represents the key for a flow.
|
|
||||||
type flowKey struct {
|
|
||||||
srcAddr, dstAddr [16]byte
|
|
||||||
srcPort, dstPort uint16
|
|
||||||
rxAck uint32 // varying ack values should not be coalesced. Treat them as separate flows.
|
|
||||||
}
|
|
||||||
|
|
||||||
// tcpGROTable holds flow and coalescing information for the purposes of GRO.
|
|
||||||
type tcpGROTable struct {
|
|
||||||
itemsByFlow map[flowKey][]tcpGROItem
|
|
||||||
itemsPool [][]tcpGROItem
|
|
||||||
}
|
|
||||||
|
|
||||||
func newTCPGROTable() *tcpGROTable {
|
|
||||||
t := &tcpGROTable{
|
|
||||||
itemsByFlow: make(map[flowKey][]tcpGROItem, conn.IdealBatchSize),
|
|
||||||
itemsPool: make([][]tcpGROItem, conn.IdealBatchSize),
|
|
||||||
}
|
|
||||||
for i := range t.itemsPool {
|
|
||||||
t.itemsPool[i] = make([]tcpGROItem, 0, conn.IdealBatchSize)
|
|
||||||
}
|
|
||||||
return t
|
|
||||||
}
|
|
||||||
|
|
||||||
func newFlowKey(pkt []byte, srcAddr, dstAddr, tcphOffset int) flowKey {
|
|
||||||
key := flowKey{}
|
|
||||||
addrSize := dstAddr - srcAddr
|
|
||||||
copy(key.srcAddr[:], pkt[srcAddr:dstAddr])
|
|
||||||
copy(key.dstAddr[:], pkt[dstAddr:dstAddr+addrSize])
|
|
||||||
key.srcPort = binary.BigEndian.Uint16(pkt[tcphOffset:])
|
|
||||||
key.dstPort = binary.BigEndian.Uint16(pkt[tcphOffset+2:])
|
|
||||||
key.rxAck = binary.BigEndian.Uint32(pkt[tcphOffset+8:])
|
|
||||||
return key
|
|
||||||
}
|
|
||||||
|
|
||||||
// lookupOrInsert looks up a flow for the provided packet and metadata,
|
|
||||||
// returning the packets found for the flow, or inserting a new one if none
|
|
||||||
// is found.
|
|
||||||
func (t *tcpGROTable) lookupOrInsert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) ([]tcpGROItem, bool) {
|
|
||||||
key := newFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset)
|
|
||||||
items, ok := t.itemsByFlow[key]
|
|
||||||
if ok {
|
|
||||||
return items, ok
|
|
||||||
}
|
|
||||||
// TODO: insert() performs another map lookup. This could be rearranged to avoid.
|
|
||||||
t.insert(pkt, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex)
|
|
||||||
return nil, false
|
|
||||||
}
|
|
||||||
|
|
||||||
// insert an item in the table for the provided packet and packet metadata.
|
|
||||||
func (t *tcpGROTable) insert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) {
|
|
||||||
key := newFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset)
|
|
||||||
item := tcpGROItem{
|
|
||||||
key: key,
|
|
||||||
bufsIndex: uint16(bufsIndex),
|
|
||||||
gsoSize: uint16(len(pkt[tcphOffset+tcphLen:])),
|
|
||||||
iphLen: uint8(tcphOffset),
|
|
||||||
tcphLen: uint8(tcphLen),
|
|
||||||
sentSeq: binary.BigEndian.Uint32(pkt[tcphOffset+4:]),
|
|
||||||
pshSet: pkt[tcphOffset+tcpFlagsOffset]&tcpFlagPSH != 0,
|
|
||||||
}
|
|
||||||
items, ok := t.itemsByFlow[key]
|
|
||||||
if !ok {
|
|
||||||
items = t.newItems()
|
|
||||||
}
|
|
||||||
items = append(items, item)
|
|
||||||
t.itemsByFlow[key] = items
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tcpGROTable) updateAt(item tcpGROItem, i int) {
|
|
||||||
items, _ := t.itemsByFlow[item.key]
|
|
||||||
items[i] = item
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tcpGROTable) deleteAt(key flowKey, i int) {
|
|
||||||
items, _ := t.itemsByFlow[key]
|
|
||||||
items = append(items[:i], items[i+1:]...)
|
|
||||||
t.itemsByFlow[key] = items
|
|
||||||
}
|
|
||||||
|
|
||||||
// tcpGROItem represents bookkeeping data for a TCP packet during the lifetime
|
|
||||||
// of a GRO evaluation across a vector of packets.
|
|
||||||
type tcpGROItem struct {
|
|
||||||
key flowKey
|
|
||||||
sentSeq uint32 // the sequence number
|
|
||||||
bufsIndex uint16 // the index into the original bufs slice
|
|
||||||
numMerged uint16 // the number of packets merged into this item
|
|
||||||
gsoSize uint16 // payload size
|
|
||||||
iphLen uint8 // ip header len
|
|
||||||
tcphLen uint8 // tcp header len
|
|
||||||
pshSet bool // psh flag is set
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tcpGROTable) newItems() []tcpGROItem {
|
|
||||||
var items []tcpGROItem
|
|
||||||
items, t.itemsPool = t.itemsPool[len(t.itemsPool)-1], t.itemsPool[:len(t.itemsPool)-1]
|
|
||||||
return items
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tcpGROTable) reset() {
|
|
||||||
for k, items := range t.itemsByFlow {
|
|
||||||
items = items[:0]
|
|
||||||
t.itemsPool = append(t.itemsPool, items)
|
|
||||||
delete(t.itemsByFlow, k)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// canCoalesce represents the outcome of checking if two TCP packets are
|
|
||||||
// candidates for coalescing.
|
|
||||||
type canCoalesce int
|
|
||||||
|
|
||||||
const (
|
|
||||||
coalescePrepend canCoalesce = -1
|
|
||||||
coalesceUnavailable canCoalesce = 0
|
|
||||||
coalesceAppend canCoalesce = 1
|
|
||||||
)
|
|
||||||
|
|
||||||
// tcpPacketsCanCoalesce evaluates if pkt can be coalesced with the packet
|
|
||||||
// described by item. This function makes considerations that match the kernel's
|
|
||||||
// GRO self tests, which can be found in tools/testing/selftests/net/gro.c.
|
|
||||||
func tcpPacketsCanCoalesce(pkt []byte, iphLen, tcphLen uint8, seq uint32, pshSet bool, gsoSize uint16, item tcpGROItem, bufs [][]byte, bufsOffset int) canCoalesce {
|
|
||||||
pktTarget := bufs[item.bufsIndex][bufsOffset:]
|
|
||||||
if tcphLen != item.tcphLen {
|
|
||||||
// cannot coalesce with unequal tcp options len
|
|
||||||
return coalesceUnavailable
|
|
||||||
}
|
|
||||||
if tcphLen > 20 {
|
|
||||||
if !bytes.Equal(pkt[iphLen+20:iphLen+tcphLen], pktTarget[item.iphLen+20:iphLen+tcphLen]) {
|
|
||||||
// cannot coalesce with unequal tcp options
|
|
||||||
return coalesceUnavailable
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if pkt[0]>>4 == 6 {
|
|
||||||
if pkt[0] != pktTarget[0] || pkt[1]>>4 != pktTarget[1]>>4 {
|
|
||||||
// cannot coalesce with unequal Traffic class values
|
|
||||||
return coalesceUnavailable
|
|
||||||
}
|
|
||||||
if pkt[7] != pktTarget[7] {
|
|
||||||
// cannot coalesce with unequal Hop limit values
|
|
||||||
return coalesceUnavailable
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if pkt[1] != pktTarget[1] {
|
|
||||||
// cannot coalesce with unequal ToS values
|
|
||||||
return coalesceUnavailable
|
|
||||||
}
|
|
||||||
if pkt[6]>>5 != pktTarget[6]>>5 {
|
|
||||||
// cannot coalesce with unequal DF or reserved bits. MF is checked
|
|
||||||
// further up the stack.
|
|
||||||
return coalesceUnavailable
|
|
||||||
}
|
|
||||||
if pkt[8] != pktTarget[8] {
|
|
||||||
// cannot coalesce with unequal TTL values
|
|
||||||
return coalesceUnavailable
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// seq adjacency
|
|
||||||
lhsLen := item.gsoSize
|
|
||||||
lhsLen += item.numMerged * item.gsoSize
|
|
||||||
if seq == item.sentSeq+uint32(lhsLen) { // pkt aligns following item from a seq num perspective
|
|
||||||
if item.pshSet {
|
|
||||||
// We cannot append to a segment that has the PSH flag set, PSH
|
|
||||||
// can only be set on the final segment in a reassembled group.
|
|
||||||
return coalesceUnavailable
|
|
||||||
}
|
|
||||||
if len(pktTarget[iphLen+tcphLen:])%int(item.gsoSize) != 0 {
|
|
||||||
// A smaller than gsoSize packet has been appended previously.
|
|
||||||
// Nothing can come after a smaller packet on the end.
|
|
||||||
return coalesceUnavailable
|
|
||||||
}
|
|
||||||
if gsoSize > item.gsoSize {
|
|
||||||
// We cannot have a larger packet following a smaller one.
|
|
||||||
return coalesceUnavailable
|
|
||||||
}
|
|
||||||
return coalesceAppend
|
|
||||||
} else if seq+uint32(gsoSize) == item.sentSeq { // pkt aligns in front of item from a seq num perspective
|
|
||||||
if pshSet {
|
|
||||||
// We cannot prepend with a segment that has the PSH flag set, PSH
|
|
||||||
// can only be set on the final segment in a reassembled group.
|
|
||||||
return coalesceUnavailable
|
|
||||||
}
|
|
||||||
if gsoSize < item.gsoSize {
|
|
||||||
// We cannot have a larger packet following a smaller one.
|
|
||||||
return coalesceUnavailable
|
|
||||||
}
|
|
||||||
if gsoSize > item.gsoSize && item.numMerged > 0 {
|
|
||||||
// There's at least one previous merge, and we're larger than all
|
|
||||||
// previous. This would put multiple smaller packets on the end.
|
|
||||||
return coalesceUnavailable
|
|
||||||
}
|
|
||||||
return coalescePrepend
|
|
||||||
}
|
|
||||||
return coalesceUnavailable
|
|
||||||
}
|
|
||||||
|
|
||||||
func tcpChecksumValid(pkt []byte, iphLen uint8, isV6 bool) bool {
|
|
||||||
srcAddrAt := ipv4SrcAddrOffset
|
|
||||||
addrSize := 4
|
|
||||||
if isV6 {
|
|
||||||
srcAddrAt = ipv6SrcAddrOffset
|
|
||||||
addrSize = 16
|
|
||||||
}
|
|
||||||
tcpTotalLen := uint16(len(pkt) - int(iphLen))
|
|
||||||
tcpCSumNoFold := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, pkt[srcAddrAt:srcAddrAt+addrSize], pkt[srcAddrAt+addrSize:srcAddrAt+addrSize*2], tcpTotalLen)
|
|
||||||
return ^checksum(pkt[iphLen:], tcpCSumNoFold) == 0
|
|
||||||
}
|
|
||||||
|
|
||||||
// coalesceResult represents the result of attempting to coalesce two TCP
|
|
||||||
// packets.
|
|
||||||
type coalesceResult int
|
|
||||||
|
|
||||||
const (
|
|
||||||
coalesceInsufficientCap coalesceResult = 0
|
|
||||||
coalescePSHEnding coalesceResult = 1
|
|
||||||
coalesceItemInvalidCSum coalesceResult = 2
|
|
||||||
coalescePktInvalidCSum coalesceResult = 3
|
|
||||||
coalesceSuccess coalesceResult = 4
|
|
||||||
)
|
|
||||||
|
|
||||||
// coalesceTCPPackets attempts to coalesce pkt with the packet described by
|
|
||||||
// item, returning the outcome. This function may swap bufs elements in the
|
|
||||||
// event of a prepend as item's bufs index is already being tracked for writing
|
|
||||||
// to a Device.
|
|
||||||
func coalesceTCPPackets(mode canCoalesce, pkt []byte, pktBuffsIndex int, gsoSize uint16, seq uint32, pshSet bool, item *tcpGROItem, bufs [][]byte, bufsOffset int, isV6 bool) coalesceResult {
|
|
||||||
var pktHead []byte // the packet that will end up at the front
|
|
||||||
headersLen := item.iphLen + item.tcphLen
|
|
||||||
coalescedLen := len(bufs[item.bufsIndex][bufsOffset:]) + len(pkt) - int(headersLen)
|
|
||||||
|
|
||||||
// Copy data
|
|
||||||
if mode == coalescePrepend {
|
|
||||||
pktHead = pkt
|
|
||||||
if cap(pkt)-bufsOffset < coalescedLen {
|
|
||||||
// We don't want to allocate a new underlying array if capacity is
|
|
||||||
// too small.
|
|
||||||
return coalesceInsufficientCap
|
|
||||||
}
|
|
||||||
if pshSet {
|
|
||||||
return coalescePSHEnding
|
|
||||||
}
|
|
||||||
if item.numMerged == 0 {
|
|
||||||
if !tcpChecksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, isV6) {
|
|
||||||
return coalesceItemInvalidCSum
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !tcpChecksumValid(pkt, item.iphLen, isV6) {
|
|
||||||
return coalescePktInvalidCSum
|
|
||||||
}
|
|
||||||
item.sentSeq = seq
|
|
||||||
extendBy := coalescedLen - len(pktHead)
|
|
||||||
bufs[pktBuffsIndex] = append(bufs[pktBuffsIndex], make([]byte, extendBy)...)
|
|
||||||
copy(bufs[pktBuffsIndex][bufsOffset+len(pkt):], bufs[item.bufsIndex][bufsOffset+int(headersLen):])
|
|
||||||
// Flip the slice headers in bufs as part of prepend. The index of item
|
|
||||||
// is already being tracked for writing.
|
|
||||||
bufs[item.bufsIndex], bufs[pktBuffsIndex] = bufs[pktBuffsIndex], bufs[item.bufsIndex]
|
|
||||||
} else {
|
|
||||||
pktHead = bufs[item.bufsIndex][bufsOffset:]
|
|
||||||
if cap(pktHead)-bufsOffset < coalescedLen {
|
|
||||||
// We don't want to allocate a new underlying array if capacity is
|
|
||||||
// too small.
|
|
||||||
return coalesceInsufficientCap
|
|
||||||
}
|
|
||||||
if item.numMerged == 0 {
|
|
||||||
if !tcpChecksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, isV6) {
|
|
||||||
return coalesceItemInvalidCSum
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !tcpChecksumValid(pkt, item.iphLen, isV6) {
|
|
||||||
return coalescePktInvalidCSum
|
|
||||||
}
|
|
||||||
if pshSet {
|
|
||||||
// We are appending a segment with PSH set.
|
|
||||||
item.pshSet = pshSet
|
|
||||||
pktHead[item.iphLen+tcpFlagsOffset] |= tcpFlagPSH
|
|
||||||
}
|
|
||||||
extendBy := len(pkt) - int(headersLen)
|
|
||||||
bufs[item.bufsIndex] = append(bufs[item.bufsIndex], make([]byte, extendBy)...)
|
|
||||||
copy(bufs[item.bufsIndex][bufsOffset+len(pktHead):], pkt[headersLen:])
|
|
||||||
}
|
|
||||||
|
|
||||||
if gsoSize > item.gsoSize {
|
|
||||||
item.gsoSize = gsoSize
|
|
||||||
}
|
|
||||||
hdr := virtioNetHdr{
|
|
||||||
flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, // this turns into CHECKSUM_PARTIAL in the skb
|
|
||||||
hdrLen: uint16(headersLen),
|
|
||||||
gsoSize: uint16(item.gsoSize),
|
|
||||||
csumStart: uint16(item.iphLen),
|
|
||||||
csumOffset: 16,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Recalculate the total len (IPv4) or payload len (IPv6). Recalculate the
|
|
||||||
// (IPv4) header checksum.
|
|
||||||
if isV6 {
|
|
||||||
hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV6
|
|
||||||
binary.BigEndian.PutUint16(pktHead[4:], uint16(coalescedLen)-uint16(item.iphLen)) // set new payload len
|
|
||||||
} else {
|
|
||||||
hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV4
|
|
||||||
pktHead[10], pktHead[11] = 0, 0 // clear checksum field
|
|
||||||
binary.BigEndian.PutUint16(pktHead[2:], uint16(coalescedLen)) // set new total length
|
|
||||||
iphCSum := ^checksum(pktHead[:item.iphLen], 0) // compute checksum
|
|
||||||
binary.BigEndian.PutUint16(pktHead[10:], iphCSum) // set checksum field
|
|
||||||
}
|
|
||||||
hdr.encode(bufs[item.bufsIndex][bufsOffset-virtioNetHdrLen:])
|
|
||||||
|
|
||||||
// Calculate the pseudo header checksum and place it at the TCP checksum
|
|
||||||
// offset. Downstream checksum offloading will combine this with computation
|
|
||||||
// of the tcp header and payload checksum.
|
|
||||||
addrLen := 4
|
|
||||||
addrOffset := ipv4SrcAddrOffset
|
|
||||||
if isV6 {
|
|
||||||
addrLen = 16
|
|
||||||
addrOffset = ipv6SrcAddrOffset
|
|
||||||
}
|
|
||||||
srcAddrAt := bufsOffset + addrOffset
|
|
||||||
srcAddr := bufs[item.bufsIndex][srcAddrAt : srcAddrAt+addrLen]
|
|
||||||
dstAddr := bufs[item.bufsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2]
|
|
||||||
psum := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(coalescedLen-int(item.iphLen)))
|
|
||||||
binary.BigEndian.PutUint16(pktHead[hdr.csumStart+hdr.csumOffset:], checksum([]byte{}, psum))
|
|
||||||
|
|
||||||
item.numMerged++
|
|
||||||
return coalesceSuccess
|
|
||||||
}
|
|
||||||
|
|
||||||
const (
|
|
||||||
ipv4FlagMoreFragments uint8 = 0x20
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
ipv4SrcAddrOffset = 12
|
|
||||||
ipv6SrcAddrOffset = 8
|
|
||||||
maxUint16 = 1<<16 - 1
|
|
||||||
)
|
|
||||||
|
|
||||||
// tcpGRO evaluates the TCP packet at pktI in bufs for coalescing with
|
|
||||||
// existing packets tracked in table. It will return false when pktI is not
|
|
||||||
// coalesced, otherwise true. This indicates to the caller if bufs[pktI]
|
|
||||||
// should be written to the Device.
|
|
||||||
func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) (pktCoalesced bool) {
|
|
||||||
pkt := bufs[pktI][offset:]
|
|
||||||
if len(pkt) > maxUint16 {
|
|
||||||
// A valid IPv4 or IPv6 packet will never exceed this.
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
iphLen := int((pkt[0] & 0x0F) * 4)
|
|
||||||
if isV6 {
|
|
||||||
iphLen = 40
|
|
||||||
ipv6HPayloadLen := int(binary.BigEndian.Uint16(pkt[4:]))
|
|
||||||
if ipv6HPayloadLen != len(pkt)-iphLen {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
totalLen := int(binary.BigEndian.Uint16(pkt[2:]))
|
|
||||||
if totalLen != len(pkt) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(pkt) < iphLen {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
tcphLen := int((pkt[iphLen+12] >> 4) * 4)
|
|
||||||
if tcphLen < 20 || tcphLen > 60 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if len(pkt) < iphLen+tcphLen {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if !isV6 {
|
|
||||||
if pkt[6]&ipv4FlagMoreFragments != 0 || pkt[6]<<3 != 0 || pkt[7] != 0 {
|
|
||||||
// no GRO support for fragmented segments for now
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
tcpFlags := pkt[iphLen+tcpFlagsOffset]
|
|
||||||
var pshSet bool
|
|
||||||
// not a candidate if any non-ACK flags (except PSH+ACK) are set
|
|
||||||
if tcpFlags != tcpFlagACK {
|
|
||||||
if pkt[iphLen+tcpFlagsOffset] != tcpFlagACK|tcpFlagPSH {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
pshSet = true
|
|
||||||
}
|
|
||||||
gsoSize := uint16(len(pkt) - tcphLen - iphLen)
|
|
||||||
// not a candidate if payload len is 0
|
|
||||||
if gsoSize < 1 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
seq := binary.BigEndian.Uint32(pkt[iphLen+4:])
|
|
||||||
srcAddrOffset := ipv4SrcAddrOffset
|
|
||||||
addrLen := 4
|
|
||||||
if isV6 {
|
|
||||||
srcAddrOffset = ipv6SrcAddrOffset
|
|
||||||
addrLen = 16
|
|
||||||
}
|
|
||||||
items, existing := table.lookupOrInsert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI)
|
|
||||||
if !existing {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
for i := len(items) - 1; i >= 0; i-- {
|
|
||||||
// In the best case of packets arriving in order iterating in reverse is
|
|
||||||
// more efficient if there are multiple items for a given flow. This
|
|
||||||
// also enables a natural table.deleteAt() in the
|
|
||||||
// coalesceItemInvalidCSum case without the need for index tracking.
|
|
||||||
// This algorithm makes a best effort to coalesce in the event of
|
|
||||||
// unordered packets, where pkt may land anywhere in items from a
|
|
||||||
// sequence number perspective, however once an item is inserted into
|
|
||||||
// the table it is never compared across other items later.
|
|
||||||
item := items[i]
|
|
||||||
can := tcpPacketsCanCoalesce(pkt, uint8(iphLen), uint8(tcphLen), seq, pshSet, gsoSize, item, bufs, offset)
|
|
||||||
if can != coalesceUnavailable {
|
|
||||||
result := coalesceTCPPackets(can, pkt, pktI, gsoSize, seq, pshSet, &item, bufs, offset, isV6)
|
|
||||||
switch result {
|
|
||||||
case coalesceSuccess:
|
|
||||||
table.updateAt(item, i)
|
|
||||||
return true
|
|
||||||
case coalesceItemInvalidCSum:
|
|
||||||
// delete the item with an invalid csum
|
|
||||||
table.deleteAt(item.key, i)
|
|
||||||
case coalescePktInvalidCSum:
|
|
||||||
// no point in inserting an item that we can't coalesce
|
|
||||||
return false
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// failed to coalesce with any other packets; store the item in the flow
|
|
||||||
table.insert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI)
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func isTCP4NoIPOptions(b []byte) bool {
|
|
||||||
if len(b) < 40 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if b[0]>>4 != 4 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if b[0]&0x0F != 5 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if b[9] != unix.IPPROTO_TCP {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func isTCP6NoEH(b []byte) bool {
|
|
||||||
if len(b) < 60 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if b[0]>>4 != 6 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if b[6] != unix.IPPROTO_TCP {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// handleGRO evaluates bufs for GRO, and writes the indices of the resulting
|
|
||||||
// packets into toWrite. toWrite, tcp4Table, and tcp6Table should initially be
|
|
||||||
// empty (but non-nil), and are passed in to save allocs as the caller may reset
|
|
||||||
// and recycle them across vectors of packets.
|
|
||||||
func handleGRO(bufs [][]byte, offset int, tcp4Table, tcp6Table *tcpGROTable, toWrite *[]int) error {
|
|
||||||
for i := range bufs {
|
|
||||||
if offset < virtioNetHdrLen || offset > len(bufs[i])-1 {
|
|
||||||
return errors.New("invalid offset")
|
|
||||||
}
|
|
||||||
var coalesced bool
|
|
||||||
switch {
|
|
||||||
case isTCP4NoIPOptions(bufs[i][offset:]): // ipv4 packets w/IP options do not coalesce
|
|
||||||
coalesced = tcpGRO(bufs, offset, i, tcp4Table, false)
|
|
||||||
case isTCP6NoEH(bufs[i][offset:]): // ipv6 packets w/extension headers do not coalesce
|
|
||||||
coalesced = tcpGRO(bufs, offset, i, tcp6Table, true)
|
|
||||||
}
|
|
||||||
if !coalesced {
|
|
||||||
hdr := virtioNetHdr{}
|
|
||||||
err := hdr.encode(bufs[i][offset-virtioNetHdrLen:])
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
*toWrite = append(*toWrite, i)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// tcpTSO splits packets from in into outBuffs, writing the size of each
|
|
||||||
// element into sizes. It returns the number of buffers populated, and/or an
|
|
||||||
// error.
|
|
||||||
func tcpTSO(in []byte, hdr virtioNetHdr, outBuffs [][]byte, sizes []int, outOffset int) (int, error) {
|
|
||||||
iphLen := int(hdr.csumStart)
|
|
||||||
srcAddrOffset := ipv6SrcAddrOffset
|
|
||||||
addrLen := 16
|
|
||||||
if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV4 {
|
|
||||||
in[10], in[11] = 0, 0 // clear ipv4 header checksum
|
|
||||||
srcAddrOffset = ipv4SrcAddrOffset
|
|
||||||
addrLen = 4
|
|
||||||
}
|
|
||||||
tcpCSumAt := int(hdr.csumStart + hdr.csumOffset)
|
|
||||||
in[tcpCSumAt], in[tcpCSumAt+1] = 0, 0 // clear tcp checksum
|
|
||||||
firstTCPSeqNum := binary.BigEndian.Uint32(in[hdr.csumStart+4:])
|
|
||||||
nextSegmentDataAt := int(hdr.hdrLen)
|
|
||||||
i := 0
|
|
||||||
for ; nextSegmentDataAt < len(in); i++ {
|
|
||||||
if i == len(outBuffs) {
|
|
||||||
return i - 1, ErrTooManySegments
|
|
||||||
}
|
|
||||||
nextSegmentEnd := nextSegmentDataAt + int(hdr.gsoSize)
|
|
||||||
if nextSegmentEnd > len(in) {
|
|
||||||
nextSegmentEnd = len(in)
|
|
||||||
}
|
|
||||||
segmentDataLen := nextSegmentEnd - nextSegmentDataAt
|
|
||||||
totalLen := int(hdr.hdrLen) + segmentDataLen
|
|
||||||
sizes[i] = totalLen
|
|
||||||
out := outBuffs[i][outOffset:]
|
|
||||||
|
|
||||||
copy(out, in[:iphLen])
|
|
||||||
if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV4 {
|
|
||||||
// For IPv4 we are responsible for incrementing the ID field,
|
|
||||||
// updating the total len field, and recalculating the header
|
|
||||||
// checksum.
|
|
||||||
if i > 0 {
|
|
||||||
id := binary.BigEndian.Uint16(out[4:])
|
|
||||||
id += uint16(i)
|
|
||||||
binary.BigEndian.PutUint16(out[4:], id)
|
|
||||||
}
|
|
||||||
binary.BigEndian.PutUint16(out[2:], uint16(totalLen))
|
|
||||||
ipv4CSum := ^checksum(out[:iphLen], 0)
|
|
||||||
binary.BigEndian.PutUint16(out[10:], ipv4CSum)
|
|
||||||
} else {
|
|
||||||
// For IPv6 we are responsible for updating the payload length field.
|
|
||||||
binary.BigEndian.PutUint16(out[4:], uint16(totalLen-iphLen))
|
|
||||||
}
|
|
||||||
|
|
||||||
// TCP header
|
|
||||||
copy(out[hdr.csumStart:hdr.hdrLen], in[hdr.csumStart:hdr.hdrLen])
|
|
||||||
tcpSeq := firstTCPSeqNum + uint32(hdr.gsoSize*uint16(i))
|
|
||||||
binary.BigEndian.PutUint32(out[hdr.csumStart+4:], tcpSeq)
|
|
||||||
if nextSegmentEnd != len(in) {
|
|
||||||
// FIN and PSH should only be set on last segment
|
|
||||||
clearFlags := tcpFlagFIN | tcpFlagPSH
|
|
||||||
out[hdr.csumStart+tcpFlagsOffset] &^= clearFlags
|
|
||||||
}
|
|
||||||
|
|
||||||
// payload
|
|
||||||
copy(out[hdr.hdrLen:], in[nextSegmentDataAt:nextSegmentEnd])
|
|
||||||
|
|
||||||
// TCP checksum
|
|
||||||
tcpHLen := int(hdr.hdrLen - hdr.csumStart)
|
|
||||||
tcpLenForPseudo := uint16(tcpHLen + segmentDataLen)
|
|
||||||
tcpCSumNoFold := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, in[srcAddrOffset:srcAddrOffset+addrLen], in[srcAddrOffset+addrLen:srcAddrOffset+addrLen*2], tcpLenForPseudo)
|
|
||||||
tcpCSum := ^checksum(out[hdr.csumStart:totalLen], tcpCSumNoFold)
|
|
||||||
binary.BigEndian.PutUint16(out[hdr.csumStart+hdr.csumOffset:], tcpCSum)
|
|
||||||
|
|
||||||
nextSegmentDataAt += int(hdr.gsoSize)
|
|
||||||
}
|
|
||||||
return i, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func gsoNoneChecksum(in []byte, cSumStart, cSumOffset uint16) error {
|
|
||||||
cSumAt := cSumStart + cSumOffset
|
|
||||||
// The initial value at the checksum offset should be summed with the
|
|
||||||
// checksum we compute. This is typically the pseudo-header checksum.
|
|
||||||
initial := binary.BigEndian.Uint16(in[cSumAt:])
|
|
||||||
in[cSumAt], in[cSumAt+1] = 0, 0
|
|
||||||
binary.BigEndian.PutUint16(in[cSumAt:], ^checksum(in[cSumStart:], uint64(initial)))
|
|
||||||
return nil
|
|
||||||
}
|
|
|
@ -1,411 +0,0 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package tun
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/netip"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/amnezia-vpn/amnezia-wg/conn"
|
|
||||||
"golang.org/x/sys/unix"
|
|
||||||
"gvisor.dev/gvisor/pkg/tcpip"
|
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/header"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
offset = virtioNetHdrLen
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
ip4PortA = netip.MustParseAddrPort("192.0.2.1:1")
|
|
||||||
ip4PortB = netip.MustParseAddrPort("192.0.2.2:1")
|
|
||||||
ip4PortC = netip.MustParseAddrPort("192.0.2.3:1")
|
|
||||||
ip6PortA = netip.MustParseAddrPort("[2001:db8::1]:1")
|
|
||||||
ip6PortB = netip.MustParseAddrPort("[2001:db8::2]:1")
|
|
||||||
ip6PortC = netip.MustParseAddrPort("[2001:db8::3]:1")
|
|
||||||
)
|
|
||||||
|
|
||||||
func tcp4PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32, ipFn func(*header.IPv4Fields)) []byte {
|
|
||||||
totalLen := 40 + segmentSize
|
|
||||||
b := make([]byte, offset+int(totalLen), 65535)
|
|
||||||
ipv4H := header.IPv4(b[offset:])
|
|
||||||
srcAs4 := srcIPPort.Addr().As4()
|
|
||||||
dstAs4 := dstIPPort.Addr().As4()
|
|
||||||
ipFields := &header.IPv4Fields{
|
|
||||||
SrcAddr: tcpip.Address(srcAs4[:]),
|
|
||||||
DstAddr: tcpip.Address(dstAs4[:]),
|
|
||||||
Protocol: unix.IPPROTO_TCP,
|
|
||||||
TTL: 64,
|
|
||||||
TotalLength: uint16(totalLen),
|
|
||||||
}
|
|
||||||
if ipFn != nil {
|
|
||||||
ipFn(ipFields)
|
|
||||||
}
|
|
||||||
ipv4H.Encode(ipFields)
|
|
||||||
tcpH := header.TCP(b[offset+20:])
|
|
||||||
tcpH.Encode(&header.TCPFields{
|
|
||||||
SrcPort: srcIPPort.Port(),
|
|
||||||
DstPort: dstIPPort.Port(),
|
|
||||||
SeqNum: seq,
|
|
||||||
AckNum: 1,
|
|
||||||
DataOffset: 20,
|
|
||||||
Flags: flags,
|
|
||||||
WindowSize: 3000,
|
|
||||||
})
|
|
||||||
ipv4H.SetChecksum(^ipv4H.CalculateChecksum())
|
|
||||||
pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_TCP, ipv4H.SourceAddress(), ipv4H.DestinationAddress(), uint16(20+segmentSize))
|
|
||||||
tcpH.SetChecksum(^tcpH.CalculateChecksum(pseudoCsum))
|
|
||||||
return b
|
|
||||||
}
|
|
||||||
|
|
||||||
func tcp4Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32) []byte {
|
|
||||||
return tcp4PacketMutateIPFields(srcIPPort, dstIPPort, flags, segmentSize, seq, nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
func tcp6PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32, ipFn func(*header.IPv6Fields)) []byte {
|
|
||||||
totalLen := 60 + segmentSize
|
|
||||||
b := make([]byte, offset+int(totalLen), 65535)
|
|
||||||
ipv6H := header.IPv6(b[offset:])
|
|
||||||
srcAs16 := srcIPPort.Addr().As16()
|
|
||||||
dstAs16 := dstIPPort.Addr().As16()
|
|
||||||
ipFields := &header.IPv6Fields{
|
|
||||||
SrcAddr: tcpip.Address(srcAs16[:]),
|
|
||||||
DstAddr: tcpip.Address(dstAs16[:]),
|
|
||||||
TransportProtocol: unix.IPPROTO_TCP,
|
|
||||||
HopLimit: 64,
|
|
||||||
PayloadLength: uint16(segmentSize + 20),
|
|
||||||
}
|
|
||||||
if ipFn != nil {
|
|
||||||
ipFn(ipFields)
|
|
||||||
}
|
|
||||||
ipv6H.Encode(ipFields)
|
|
||||||
tcpH := header.TCP(b[offset+40:])
|
|
||||||
tcpH.Encode(&header.TCPFields{
|
|
||||||
SrcPort: srcIPPort.Port(),
|
|
||||||
DstPort: dstIPPort.Port(),
|
|
||||||
SeqNum: seq,
|
|
||||||
AckNum: 1,
|
|
||||||
DataOffset: 20,
|
|
||||||
Flags: flags,
|
|
||||||
WindowSize: 3000,
|
|
||||||
})
|
|
||||||
pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_TCP, ipv6H.SourceAddress(), ipv6H.DestinationAddress(), uint16(20+segmentSize))
|
|
||||||
tcpH.SetChecksum(^tcpH.CalculateChecksum(pseudoCsum))
|
|
||||||
return b
|
|
||||||
}
|
|
||||||
|
|
||||||
func tcp6Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32) []byte {
|
|
||||||
return tcp6PacketMutateIPFields(srcIPPort, dstIPPort, flags, segmentSize, seq, nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_handleVirtioRead(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
hdr virtioNetHdr
|
|
||||||
pktIn []byte
|
|
||||||
wantLens []int
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
"tcp4",
|
|
||||||
virtioNetHdr{
|
|
||||||
flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
|
|
||||||
gsoType: unix.VIRTIO_NET_HDR_GSO_TCPV4,
|
|
||||||
gsoSize: 100,
|
|
||||||
hdrLen: 40,
|
|
||||||
csumStart: 20,
|
|
||||||
csumOffset: 16,
|
|
||||||
},
|
|
||||||
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck|header.TCPFlagPsh, 200, 1),
|
|
||||||
[]int{140, 140},
|
|
||||||
false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"tcp6",
|
|
||||||
virtioNetHdr{
|
|
||||||
flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
|
|
||||||
gsoType: unix.VIRTIO_NET_HDR_GSO_TCPV6,
|
|
||||||
gsoSize: 100,
|
|
||||||
hdrLen: 60,
|
|
||||||
csumStart: 40,
|
|
||||||
csumOffset: 16,
|
|
||||||
},
|
|
||||||
tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck|header.TCPFlagPsh, 200, 1),
|
|
||||||
[]int{160, 160},
|
|
||||||
false,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
out := make([][]byte, conn.IdealBatchSize)
|
|
||||||
sizes := make([]int, conn.IdealBatchSize)
|
|
||||||
for i := range out {
|
|
||||||
out[i] = make([]byte, 65535)
|
|
||||||
}
|
|
||||||
tt.hdr.encode(tt.pktIn)
|
|
||||||
n, err := handleVirtioRead(tt.pktIn, out, sizes, offset)
|
|
||||||
if err != nil {
|
|
||||||
if tt.wantErr {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
t.Fatalf("got err: %v", err)
|
|
||||||
}
|
|
||||||
if n != len(tt.wantLens) {
|
|
||||||
t.Fatalf("got %d packets, wanted %d", n, len(tt.wantLens))
|
|
||||||
}
|
|
||||||
for i := range tt.wantLens {
|
|
||||||
if tt.wantLens[i] != sizes[i] {
|
|
||||||
t.Fatalf("wantLens[%d]: %d != outSizes: %d", i, tt.wantLens[i], sizes[i])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func flipTCP4Checksum(b []byte) []byte {
|
|
||||||
at := virtioNetHdrLen + 20 + 16 // 20 byte ipv4 header; tcp csum offset is 16
|
|
||||||
b[at] ^= 0xFF
|
|
||||||
b[at+1] ^= 0xFF
|
|
||||||
return b
|
|
||||||
}
|
|
||||||
|
|
||||||
func Fuzz_handleGRO(f *testing.F) {
|
|
||||||
pkt0 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)
|
|
||||||
pkt1 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101)
|
|
||||||
pkt2 := tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201)
|
|
||||||
pkt3 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1)
|
|
||||||
pkt4 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101)
|
|
||||||
pkt5 := tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201)
|
|
||||||
f.Add(pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, offset)
|
|
||||||
f.Fuzz(func(t *testing.T, pkt0, pkt1, pkt2, pkt3, pkt4, pkt5 []byte, offset int) {
|
|
||||||
pkts := [][]byte{pkt0, pkt1, pkt2, pkt3, pkt4, pkt5}
|
|
||||||
toWrite := make([]int, 0, len(pkts))
|
|
||||||
handleGRO(pkts, offset, newTCPGROTable(), newTCPGROTable(), &toWrite)
|
|
||||||
if len(toWrite) > len(pkts) {
|
|
||||||
t.Errorf("len(toWrite): %d > len(pkts): %d", len(toWrite), len(pkts))
|
|
||||||
}
|
|
||||||
seenWriteI := make(map[int]bool)
|
|
||||||
for _, writeI := range toWrite {
|
|
||||||
if writeI < 0 || writeI > len(pkts)-1 {
|
|
||||||
t.Errorf("toWrite value (%d) outside bounds of len(pkts): %d", writeI, len(pkts))
|
|
||||||
}
|
|
||||||
if seenWriteI[writeI] {
|
|
||||||
t.Errorf("duplicate toWrite value: %d", writeI)
|
|
||||||
}
|
|
||||||
seenWriteI[writeI] = true
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_handleGRO(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
pktsIn [][]byte
|
|
||||||
wantToWrite []int
|
|
||||||
wantLens []int
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
"multiple flows",
|
|
||||||
[][]byte{
|
|
||||||
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // v4 flow 1
|
|
||||||
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1
|
|
||||||
tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201), // v4 flow 2
|
|
||||||
tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), // v6 flow 1
|
|
||||||
tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101), // v6 flow 1
|
|
||||||
tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201), // v6 flow 2
|
|
||||||
},
|
|
||||||
[]int{0, 2, 3, 5},
|
|
||||||
[]int{240, 140, 260, 160},
|
|
||||||
false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"PSH interleaved",
|
|
||||||
[][]byte{
|
|
||||||
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // v4 flow 1
|
|
||||||
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck|header.TCPFlagPsh, 100, 101), // v4 flow 1
|
|
||||||
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1
|
|
||||||
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 301), // v4 flow 1
|
|
||||||
tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), // v6 flow 1
|
|
||||||
tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck|header.TCPFlagPsh, 100, 101), // v6 flow 1
|
|
||||||
tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 201), // v6 flow 1
|
|
||||||
tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 301), // v6 flow 1
|
|
||||||
},
|
|
||||||
[]int{0, 2, 4, 6},
|
|
||||||
[]int{240, 240, 260, 260},
|
|
||||||
false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"coalesceItemInvalidCSum",
|
|
||||||
[][]byte{
|
|
||||||
flipTCP4Checksum(tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)), // v4 flow 1 seq 1 len 100
|
|
||||||
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1 seq 101 len 100
|
|
||||||
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 seq 201 len 100
|
|
||||||
},
|
|
||||||
[]int{0, 1},
|
|
||||||
[]int{140, 240},
|
|
||||||
false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"out of order",
|
|
||||||
[][]byte{
|
|
||||||
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1 seq 101 len 100
|
|
||||||
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // v4 flow 1 seq 1 len 100
|
|
||||||
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 seq 201 len 100
|
|
||||||
},
|
|
||||||
[]int{0},
|
|
||||||
[]int{340},
|
|
||||||
false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"tcp4 unequal TTL",
|
|
||||||
[][]byte{
|
|
||||||
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),
|
|
||||||
tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) {
|
|
||||||
fields.TTL++
|
|
||||||
}),
|
|
||||||
},
|
|
||||||
[]int{0, 1},
|
|
||||||
[]int{140, 140},
|
|
||||||
false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"tcp4 unequal ToS",
|
|
||||||
[][]byte{
|
|
||||||
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),
|
|
||||||
tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) {
|
|
||||||
fields.TOS++
|
|
||||||
}),
|
|
||||||
},
|
|
||||||
[]int{0, 1},
|
|
||||||
[]int{140, 140},
|
|
||||||
false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"tcp4 unequal flags more fragments set",
|
|
||||||
[][]byte{
|
|
||||||
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),
|
|
||||||
tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) {
|
|
||||||
fields.Flags = 1
|
|
||||||
}),
|
|
||||||
},
|
|
||||||
[]int{0, 1},
|
|
||||||
[]int{140, 140},
|
|
||||||
false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"tcp4 unequal flags DF set",
|
|
||||||
[][]byte{
|
|
||||||
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),
|
|
||||||
tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) {
|
|
||||||
fields.Flags = 2
|
|
||||||
}),
|
|
||||||
},
|
|
||||||
[]int{0, 1},
|
|
||||||
[]int{140, 140},
|
|
||||||
false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"tcp6 unequal hop limit",
|
|
||||||
[][]byte{
|
|
||||||
tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1),
|
|
||||||
tcp6PacketMutateIPFields(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv6Fields) {
|
|
||||||
fields.HopLimit++
|
|
||||||
}),
|
|
||||||
},
|
|
||||||
[]int{0, 1},
|
|
||||||
[]int{160, 160},
|
|
||||||
false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"tcp6 unequal traffic class",
|
|
||||||
[][]byte{
|
|
||||||
tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1),
|
|
||||||
tcp6PacketMutateIPFields(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv6Fields) {
|
|
||||||
fields.TrafficClass++
|
|
||||||
}),
|
|
||||||
},
|
|
||||||
[]int{0, 1},
|
|
||||||
[]int{160, 160},
|
|
||||||
false,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
toWrite := make([]int, 0, len(tt.pktsIn))
|
|
||||||
err := handleGRO(tt.pktsIn, offset, newTCPGROTable(), newTCPGROTable(), &toWrite)
|
|
||||||
if err != nil {
|
|
||||||
if tt.wantErr {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
t.Fatalf("got err: %v", err)
|
|
||||||
}
|
|
||||||
if len(toWrite) != len(tt.wantToWrite) {
|
|
||||||
t.Fatalf("got %d packets, wanted %d", len(toWrite), len(tt.wantToWrite))
|
|
||||||
}
|
|
||||||
for i, pktI := range tt.wantToWrite {
|
|
||||||
if tt.wantToWrite[i] != toWrite[i] {
|
|
||||||
t.Fatalf("wantToWrite[%d]: %d != toWrite: %d", i, tt.wantToWrite[i], toWrite[i])
|
|
||||||
}
|
|
||||||
if tt.wantLens[i] != len(tt.pktsIn[pktI][offset:]) {
|
|
||||||
t.Errorf("wanted len %d packet at %d, got: %d", tt.wantLens[i], i, len(tt.pktsIn[pktI][offset:]))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_isTCP4NoIPOptions(t *testing.T) {
|
|
||||||
valid := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)[virtioNetHdrLen:]
|
|
||||||
invalidLen := valid[:39]
|
|
||||||
invalidHeaderLen := make([]byte, len(valid))
|
|
||||||
copy(invalidHeaderLen, valid)
|
|
||||||
invalidHeaderLen[0] = 0x46
|
|
||||||
invalidProtocol := make([]byte, len(valid))
|
|
||||||
copy(invalidProtocol, valid)
|
|
||||||
invalidProtocol[9] = unix.IPPROTO_TCP + 1
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
b []byte
|
|
||||||
want bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
"valid",
|
|
||||||
valid,
|
|
||||||
true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"invalid length",
|
|
||||||
invalidLen,
|
|
||||||
false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"invalid version",
|
|
||||||
[]byte{0x00},
|
|
||||||
false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"invalid header len",
|
|
||||||
invalidHeaderLen,
|
|
||||||
false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"invalid protocol",
|
|
||||||
invalidProtocol,
|
|
||||||
false,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
if got := isTCP4NoIPOptions(tt.b); got != tt.want {
|
|
||||||
t.Errorf("isTCP4NoIPOptions() = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,8 +0,0 @@
|
||||||
go test fuzz v1
|
|
||||||
[]byte("0")
|
|
||||||
[]byte("0")
|
|
||||||
[]byte("0")
|
|
||||||
[]byte("0")
|
|
||||||
[]byte("0")
|
|
||||||
[]byte("0")
|
|
||||||
int(34)
|
|
|
@ -1,8 +0,0 @@
|
||||||
go test fuzz v1
|
|
||||||
[]byte("0")
|
|
||||||
[]byte("0")
|
|
||||||
[]byte("0")
|
|
||||||
[]byte("0")
|
|
||||||
[]byte("0")
|
|
||||||
[]byte("0")
|
|
||||||
int(-48)
|
|
|
@ -17,8 +17,8 @@ import (
|
||||||
"time"
|
"time"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
"github.com/amnezia-vpn/amnezia-wg/conn"
|
"github.com/amnezia-vpn/amneziawg-go/conn"
|
||||||
"github.com/amnezia-vpn/amnezia-wg/rwcancel"
|
"github.com/amnezia-vpn/amneziawg-go/rwcancel"
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -38,6 +38,7 @@ type NativeTun struct {
|
||||||
statusListenersShutdown chan struct{}
|
statusListenersShutdown chan struct{}
|
||||||
batchSize int
|
batchSize int
|
||||||
vnetHdr bool
|
vnetHdr bool
|
||||||
|
udpGSO bool
|
||||||
|
|
||||||
closeOnce sync.Once
|
closeOnce sync.Once
|
||||||
|
|
||||||
|
@ -48,9 +49,10 @@ type NativeTun struct {
|
||||||
readOpMu sync.Mutex // readOpMu guards readBuff
|
readOpMu sync.Mutex // readOpMu guards readBuff
|
||||||
readBuff [virtioNetHdrLen + 65535]byte // if vnetHdr every read() is prefixed by virtioNetHdr
|
readBuff [virtioNetHdrLen + 65535]byte // if vnetHdr every read() is prefixed by virtioNetHdr
|
||||||
|
|
||||||
writeOpMu sync.Mutex // writeOpMu guards toWrite, tcp4GROTable, tcp6GROTable
|
writeOpMu sync.Mutex // writeOpMu guards toWrite, tcpGROTable
|
||||||
toWrite []int
|
toWrite []int
|
||||||
tcp4GROTable, tcp6GROTable *tcpGROTable
|
tcpGROTable *tcpGROTable
|
||||||
|
udpGROTable *udpGROTable
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tun *NativeTun) File() *os.File {
|
func (tun *NativeTun) File() *os.File {
|
||||||
|
@ -333,8 +335,8 @@ func (tun *NativeTun) nameSlow() (string, error) {
|
||||||
func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) {
|
func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) {
|
||||||
tun.writeOpMu.Lock()
|
tun.writeOpMu.Lock()
|
||||||
defer func() {
|
defer func() {
|
||||||
tun.tcp4GROTable.reset()
|
tun.tcpGROTable.reset()
|
||||||
tun.tcp6GROTable.reset()
|
tun.udpGROTable.reset()
|
||||||
tun.writeOpMu.Unlock()
|
tun.writeOpMu.Unlock()
|
||||||
}()
|
}()
|
||||||
var (
|
var (
|
||||||
|
@ -343,7 +345,7 @@ func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) {
|
||||||
)
|
)
|
||||||
tun.toWrite = tun.toWrite[:0]
|
tun.toWrite = tun.toWrite[:0]
|
||||||
if tun.vnetHdr {
|
if tun.vnetHdr {
|
||||||
err := handleGRO(bufs, offset, tun.tcp4GROTable, tun.tcp6GROTable, &tun.toWrite)
|
err := handleGRO(bufs, offset, tun.tcpGROTable, tun.udpGROTable, tun.udpGSO, &tun.toWrite)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
@ -394,37 +396,42 @@ func handleVirtioRead(in []byte, bufs [][]byte, sizes []int, offset int) (int, e
|
||||||
sizes[0] = n
|
sizes[0] = n
|
||||||
return 1, nil
|
return 1, nil
|
||||||
}
|
}
|
||||||
if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 {
|
if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_UDP_L4 {
|
||||||
return 0, fmt.Errorf("unsupported virtio GSO type: %d", hdr.gsoType)
|
return 0, fmt.Errorf("unsupported virtio GSO type: %d", hdr.gsoType)
|
||||||
}
|
}
|
||||||
|
|
||||||
ipVersion := in[0] >> 4
|
ipVersion := in[0] >> 4
|
||||||
switch ipVersion {
|
switch ipVersion {
|
||||||
case 4:
|
case 4:
|
||||||
if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 {
|
if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_UDP_L4 {
|
||||||
return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType)
|
return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType)
|
||||||
}
|
}
|
||||||
case 6:
|
case 6:
|
||||||
if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 {
|
if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_UDP_L4 {
|
||||||
return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType)
|
return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType)
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
return 0, fmt.Errorf("invalid ip header version: %d", ipVersion)
|
return 0, fmt.Errorf("invalid ip header version: %d", ipVersion)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(in) <= int(hdr.csumStart+12) {
|
|
||||||
return 0, errors.New("packet is too short")
|
|
||||||
}
|
|
||||||
// Don't trust hdr.hdrLen from the kernel as it can be equal to the length
|
// Don't trust hdr.hdrLen from the kernel as it can be equal to the length
|
||||||
// of the entire first packet when the kernel is handling it as part of a
|
// of the entire first packet when the kernel is handling it as part of a
|
||||||
// FORWARD path. Instead, parse the TCP header length and add it onto
|
// FORWARD path. Instead, parse the transport header length and add it onto
|
||||||
// csumStart, which is synonymous for IP header length.
|
// csumStart, which is synonymous for IP header length.
|
||||||
tcpHLen := uint16(in[hdr.csumStart+12] >> 4 * 4)
|
if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_UDP_L4 {
|
||||||
if tcpHLen < 20 || tcpHLen > 60 {
|
hdr.hdrLen = hdr.csumStart + 8
|
||||||
// A TCP header must be between 20 and 60 bytes in length.
|
} else {
|
||||||
return 0, fmt.Errorf("tcp header len is invalid: %d", tcpHLen)
|
if len(in) <= int(hdr.csumStart+12) {
|
||||||
|
return 0, errors.New("packet is too short")
|
||||||
|
}
|
||||||
|
|
||||||
|
tcpHLen := uint16(in[hdr.csumStart+12] >> 4 * 4)
|
||||||
|
if tcpHLen < 20 || tcpHLen > 60 {
|
||||||
|
// A TCP header must be between 20 and 60 bytes in length.
|
||||||
|
return 0, fmt.Errorf("tcp header len is invalid: %d", tcpHLen)
|
||||||
|
}
|
||||||
|
hdr.hdrLen = hdr.csumStart + tcpHLen
|
||||||
}
|
}
|
||||||
hdr.hdrLen = hdr.csumStart + tcpHLen
|
|
||||||
|
|
||||||
if len(in) < int(hdr.hdrLen) {
|
if len(in) < int(hdr.hdrLen) {
|
||||||
return 0, fmt.Errorf("length of packet (%d) < virtioNetHdr.hdrLen (%d)", len(in), hdr.hdrLen)
|
return 0, fmt.Errorf("length of packet (%d) < virtioNetHdr.hdrLen (%d)", len(in), hdr.hdrLen)
|
||||||
|
@ -438,7 +445,7 @@ func handleVirtioRead(in []byte, bufs [][]byte, sizes []int, offset int) (int, e
|
||||||
return 0, fmt.Errorf("end of checksum offset (%d) exceeds packet length (%d)", cSumAt+1, len(in))
|
return 0, fmt.Errorf("end of checksum offset (%d) exceeds packet length (%d)", cSumAt+1, len(in))
|
||||||
}
|
}
|
||||||
|
|
||||||
return tcpTSO(in, hdr, bufs, sizes, offset)
|
return gsoSplit(in, hdr, bufs, sizes, offset, ipVersion == 6)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) {
|
func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) {
|
||||||
|
@ -497,7 +504,8 @@ func (tun *NativeTun) BatchSize() int {
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// TODO: support TSO with ECN bits
|
// TODO: support TSO with ECN bits
|
||||||
tunOffloads = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6
|
tunTCPOffloads = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6
|
||||||
|
tunUDPOffloads = unix.TUN_F_USO4 | unix.TUN_F_USO6
|
||||||
)
|
)
|
||||||
|
|
||||||
func (tun *NativeTun) initFromFlags(name string) error {
|
func (tun *NativeTun) initFromFlags(name string) error {
|
||||||
|
@ -519,12 +527,17 @@ func (tun *NativeTun) initFromFlags(name string) error {
|
||||||
}
|
}
|
||||||
got := ifr.Uint16()
|
got := ifr.Uint16()
|
||||||
if got&unix.IFF_VNET_HDR != 0 {
|
if got&unix.IFF_VNET_HDR != 0 {
|
||||||
err = unix.IoctlSetInt(int(fd), unix.TUNSETOFFLOAD, tunOffloads)
|
// tunTCPOffloads were added in Linux v2.6. We require their support
|
||||||
|
// if IFF_VNET_HDR is set.
|
||||||
|
err = unix.IoctlSetInt(int(fd), unix.TUNSETOFFLOAD, tunTCPOffloads)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
tun.vnetHdr = true
|
tun.vnetHdr = true
|
||||||
tun.batchSize = conn.IdealBatchSize
|
tun.batchSize = conn.IdealBatchSize
|
||||||
|
// tunUDPOffloads were added in Linux v6.2. We do not return an
|
||||||
|
// error if they are unsupported at runtime.
|
||||||
|
tun.udpGSO = unix.IoctlSetInt(int(fd), unix.TUNSETOFFLOAD, tunTCPOffloads|tunUDPOffloads) == nil
|
||||||
} else {
|
} else {
|
||||||
tun.batchSize = 1
|
tun.batchSize = 1
|
||||||
}
|
}
|
||||||
|
@ -575,8 +588,8 @@ func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
|
||||||
events: make(chan Event, 5),
|
events: make(chan Event, 5),
|
||||||
errors: make(chan error, 5),
|
errors: make(chan error, 5),
|
||||||
statusListenersShutdown: make(chan struct{}),
|
statusListenersShutdown: make(chan struct{}),
|
||||||
tcp4GROTable: newTCPGROTable(),
|
tcpGROTable: newTCPGROTable(),
|
||||||
tcp6GROTable: newTCPGROTable(),
|
udpGROTable: newUDPGROTable(),
|
||||||
toWrite: make([]int, 0, conn.IdealBatchSize),
|
toWrite: make([]int, 0, conn.IdealBatchSize),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -628,12 +641,12 @@ func CreateUnmonitoredTUNFromFD(fd int) (Device, string, error) {
|
||||||
}
|
}
|
||||||
file := os.NewFile(uintptr(fd), "/dev/tun")
|
file := os.NewFile(uintptr(fd), "/dev/tun")
|
||||||
tun := &NativeTun{
|
tun := &NativeTun{
|
||||||
tunFile: file,
|
tunFile: file,
|
||||||
events: make(chan Event, 5),
|
events: make(chan Event, 5),
|
||||||
errors: make(chan error, 5),
|
errors: make(chan error, 5),
|
||||||
tcp4GROTable: newTCPGROTable(),
|
tcpGROTable: newTCPGROTable(),
|
||||||
tcp6GROTable: newTCPGROTable(),
|
udpGROTable: newUDPGROTable(),
|
||||||
toWrite: make([]int, 0, conn.IdealBatchSize),
|
toWrite: make([]int, 0, conn.IdealBatchSize),
|
||||||
}
|
}
|
||||||
name, err := tun.Name()
|
name, err := tun.Name()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -127,6 +127,9 @@ func (tun *NativeTun) MTU() (int, error) {
|
||||||
|
|
||||||
// TODO: This is a temporary hack. We really need to be monitoring the interface in real time and adapting to MTU changes.
|
// TODO: This is a temporary hack. We really need to be monitoring the interface in real time and adapting to MTU changes.
|
||||||
func (tun *NativeTun) ForceMTU(mtu int) {
|
func (tun *NativeTun) ForceMTU(mtu int) {
|
||||||
|
if tun.close.Load() {
|
||||||
|
return
|
||||||
|
}
|
||||||
update := tun.forcedMTU != mtu
|
update := tun.forcedMTU != mtu
|
||||||
tun.forcedMTU = mtu
|
tun.forcedMTU = mtu
|
||||||
if update {
|
if update {
|
||||||
|
@ -157,11 +160,10 @@ retry:
|
||||||
packet, err := tun.session.ReceivePacket()
|
packet, err := tun.session.ReceivePacket()
|
||||||
switch err {
|
switch err {
|
||||||
case nil:
|
case nil:
|
||||||
packetSize := len(packet)
|
n := copy(bufs[0][offset:], packet)
|
||||||
copy(bufs[0][offset:], packet)
|
sizes[0] = n
|
||||||
sizes[0] = packetSize
|
|
||||||
tun.session.ReleaseReceivePacket(packet)
|
tun.session.ReleaseReceivePacket(packet)
|
||||||
tun.rate.update(uint64(packetSize))
|
tun.rate.update(uint64(n))
|
||||||
return 1, nil
|
return 1, nil
|
||||||
case windows.ERROR_NO_MORE_ITEMS:
|
case windows.ERROR_NO_MORE_ITEMS:
|
||||||
if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration {
|
if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration {
|
||||||
|
|
|
@ -11,7 +11,7 @@ import (
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
"github.com/amnezia-vpn/amnezia-wg/tun"
|
"github.com/amnezia-vpn/amneziawg-go/tun"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Ping(dst, src netip.Addr) []byte {
|
func Ping(dst, src netip.Addr) []byte {
|
||||||
|
|
Loading…
Add table
Reference in a new issue