Compare commits

..

No commits in common. "master" and "0.0.20220316" have entirely different histories.

113 changed files with 3013 additions and 6035 deletions

View file

@ -1,41 +0,0 @@
name: build-if-tag
on:
push:
tags:
- 'v[0-9]+.[0-9]+.[0-9]+'
env:
APP: amneziawg-go
jobs:
build:
runs-on: ubuntu-latest
name: build
steps:
- name: Checkout
uses: actions/checkout@v4
with:
ref: ${{ github.ref_name }}
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Setup metadata
uses: docker/metadata-action@v5
id: metadata
with:
images: amneziavpn/${{ env.APP }}
tags: type=semver,pattern={{version}}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Build
uses: docker/build-push-action@v5
with:
push: true
tags: ${{ steps.metadata.outputs.tags }}

2
.gitignore vendored
View file

@ -1 +1 @@
amneziawg-go
wireguard-go

View file

@ -1,17 +0,0 @@
FROM golang:1.24 as awg
COPY . /awg
WORKDIR /awg
RUN go mod download && \
go mod verify && \
go build -ldflags '-linkmode external -extldflags "-fno-PIC -static"' -v -o /usr/bin
FROM alpine:3.19
ARG AWGTOOLS_RELEASE="1.0.20241018"
RUN apk --no-cache add iproute2 iptables bash && \
cd /usr/bin/ && \
wget https://github.com/amnezia-vpn/amneziawg-tools/releases/download/v${AWGTOOLS_RELEASE}/alpine-3.19-amneziawg-tools.zip && \
unzip -j alpine-3.19-amneziawg-tools.zip && \
chmod +x /usr/bin/awg /usr/bin/awg-quick && \
ln -s /usr/bin/awg /usr/bin/wg && \
ln -s /usr/bin/awg-quick /usr/bin/wg-quick
COPY --from=awg /usr/bin/amneziawg-go /usr/bin/amneziawg-go

View file

@ -9,23 +9,23 @@ MAKEFLAGS += --no-print-directory
generate-version-and-build:
@export GIT_CEILING_DIRECTORIES="$(realpath $(CURDIR)/..)" && \
tag="$$(git describe --tags --dirty 2>/dev/null)" && \
tag="$$(git describe --dirty 2>/dev/null)" && \
ver="$$(printf 'package main\n\nconst Version = "%s"\n' "$$tag")" && \
[ "$$(cat version.go 2>/dev/null)" != "$$ver" ] && \
echo "$$ver" > version.go && \
git update-index --assume-unchanged version.go || true
@$(MAKE) amneziawg-go
@$(MAKE) wireguard-go
amneziawg-go: $(wildcard *.go) $(wildcard */*.go)
wireguard-go: $(wildcard *.go) $(wildcard */*.go)
go build -v -o "$@"
install: amneziawg-go
@install -v -d "$(DESTDIR)$(BINDIR)" && install -v -m 0755 "$<" "$(DESTDIR)$(BINDIR)/amneziawg-go"
install: wireguard-go
@install -v -d "$(DESTDIR)$(BINDIR)" && install -v -m 0755 "$<" "$(DESTDIR)$(BINDIR)/wireguard-go"
test:
go test ./...
clean:
rm -f amneziawg-go
rm -f wireguard-go
.PHONY: all clean test install generate-version-and-build

View file

@ -1,27 +1,24 @@
# Go Implementation of AmneziaWG
# Go Implementation of [WireGuard](https://www.wireguard.com/)
AmneziaWG is a contemporary version of the WireGuard protocol. It's a fork of WireGuard-Go and offers protection against detection by Deep Packet Inspection (DPI) systems. At the same time, it retains the simplified architecture and high performance of the original.
The precursor, WireGuard, is known for its efficiency but had issues with detection due to its distinctive packet signatures.
AmneziaWG addresses this problem by employing advanced obfuscation methods, allowing its traffic to blend seamlessly with regular internet traffic.
As a result, AmneziaWG maintains high performance while adding an extra layer of stealth, making it a superb choice for those seeking a fast and discreet VPN connection.
This is an implementation of WireGuard in Go.
## Usage
Simply run:
Most Linux kernel WireGuard users are used to adding an interface with `ip link add wg0 type wireguard`. With wireguard-go, instead simply run:
```
$ amneziawg-go wg0
$ wireguard-go wg0
```
This will create an interface and fork into the background. To remove the interface, use the usual `ip link del wg0`, or if your system does not support removing interfaces directly, you may instead remove the control socket via `rm -f /var/run/amneziawg/wg0.sock`, which will result in amneziawg-go shutting down.
This will create an interface and fork into the background. To remove the interface, use the usual `ip link del wg0`, or if your system does not support removing interfaces directly, you may instead remove the control socket via `rm -f /var/run/wireguard/wg0.sock`, which will result in wireguard-go shutting down.
To run amneziawg-go without forking to the background, pass `-f` or `--foreground`:
To run wireguard-go without forking to the background, pass `-f` or `--foreground`:
```
$ amneziawg-go -f wg0
$ wireguard-go -f wg0
```
When an interface is running, you may use [`amneziawg-tools `](https://github.com/amnezia-vpn/amneziawg-tools) to configure it, as well as the usual `ip(8)` and `ifconfig(8)` commands.
When an interface is running, you may use [`wg(8)`](https://git.zx2c4.com/wireguard-tools/about/src/man/wg.8) to configure it, as well as the usual `ip(8)` and `ifconfig(8)` commands.
To run with more logging you may set the environment variable `LOG_LEVEL=debug`.
@ -29,24 +26,52 @@ To run with more logging you may set the environment variable `LOG_LEVEL=debug`.
### Linux
This will run on Linux; you should run amnezia-wg instead of using default linux kernel module.
This will run on Linux; however you should instead use the kernel module, which is faster and better integrated into the OS. See the [installation page](https://www.wireguard.com/install/) for instructions.
### macOS
This runs on macOS using the utun driver. It does not yet support sticky sockets, and won't support fwmarks because of Darwin limitations. Since the utun driver cannot have arbitrary interface names, you must either use `utun[0-9]+` for an explicit interface name or `utun` to have the kernel select one for you. If you choose `utun` as the interface name, and the environment variable `WG_TUN_NAME_FILE` is defined, then the actual name of the interface chosen by the kernel is written to the file specified by that variable.
This runs on MacOS, you should use it from [amneziawg-apple](https://github.com/amnezia-vpn/amneziawg-apple)
### Windows
This runs on Windows, you should use it from [amneziawg-windows](https://github.com/amnezia-vpn/amneziawg-windows), which uses this as a module.
This runs on Windows, but you should instead use it from the more [fully featured Windows app](https://git.zx2c4.com/wireguard-windows/about/), which uses this as a module.
### FreeBSD
This will run on FreeBSD. It does not yet support sticky sockets. Fwmark is mapped to `SO_USER_COOKIE`.
### OpenBSD
This will run on OpenBSD. It does not yet support sticky sockets. Fwmark is mapped to `SO_RTABLE`. Since the tun driver cannot have arbitrary interface names, you must either use `tun[0-9]+` for an explicit interface name or `tun` to have the program select one for you. If you choose `tun` as the interface name, and the environment variable `WG_TUN_NAME_FILE` is defined, then the actual name of the interface chosen by the kernel is written to the file specified by that variable.
## Building
This requires an installation of the latest version of [Go](https://go.dev/).
This requires an installation of [go](https://golang.org) ≥ 1.18.
```
$ git clone https://github.com/amnezia-vpn/amneziawg-go
$ cd amneziawg-go
$ git clone https://git.zx2c4.com/wireguard-go
$ cd wireguard-go
$ make
```
## License
Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
Permission is hereby granted, free of charge, to any person obtaining a copy of
this software and associated documentation files (the "Software"), to deal in
the Software without restriction, including without limitation the rights to
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
of the Software, and to permit persons to whom the Software is furnished to do
so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

562
conn/bind_linux.go Normal file
View file

@ -0,0 +1,562 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package conn
import (
"errors"
"net"
"net/netip"
"strconv"
"sync"
"syscall"
"unsafe"
"golang.org/x/sys/unix"
)
type ipv4Source struct {
Src [4]byte
Ifindex int32
}
type ipv6Source struct {
src [16]byte
// ifindex belongs in dst.ZoneId
}
type LinuxSocketEndpoint struct {
mu sync.Mutex
dst [unsafe.Sizeof(unix.SockaddrInet6{})]byte
src [unsafe.Sizeof(ipv6Source{})]byte
isV6 bool
}
func (endpoint *LinuxSocketEndpoint) Src4() *ipv4Source { return endpoint.src4() }
func (endpoint *LinuxSocketEndpoint) Dst4() *unix.SockaddrInet4 { return endpoint.dst4() }
func (endpoint *LinuxSocketEndpoint) IsV6() bool { return endpoint.isV6 }
func (endpoint *LinuxSocketEndpoint) src4() *ipv4Source {
return (*ipv4Source)(unsafe.Pointer(&endpoint.src[0]))
}
func (endpoint *LinuxSocketEndpoint) src6() *ipv6Source {
return (*ipv6Source)(unsafe.Pointer(&endpoint.src[0]))
}
func (endpoint *LinuxSocketEndpoint) dst4() *unix.SockaddrInet4 {
return (*unix.SockaddrInet4)(unsafe.Pointer(&endpoint.dst[0]))
}
func (endpoint *LinuxSocketEndpoint) dst6() *unix.SockaddrInet6 {
return (*unix.SockaddrInet6)(unsafe.Pointer(&endpoint.dst[0]))
}
// LinuxSocketBind uses sendmsg and recvmsg to implement a full bind with sticky sockets on Linux.
type LinuxSocketBind struct {
// mu guards sock4 and sock6 and the associated fds.
// As long as someone holds mu (read or write), the associated fds are valid.
mu sync.RWMutex
sock4 int
sock6 int
}
func NewLinuxSocketBind() Bind { return &LinuxSocketBind{sock4: -1, sock6: -1} }
func NewDefaultBind() Bind { return NewLinuxSocketBind() }
var (
_ Endpoint = (*LinuxSocketEndpoint)(nil)
_ Bind = (*LinuxSocketBind)(nil)
)
func (*LinuxSocketBind) ParseEndpoint(s string) (Endpoint, error) {
var end LinuxSocketEndpoint
e, err := netip.ParseAddrPort(s)
if err != nil {
return nil, err
}
if e.Addr().Is4() {
dst := end.dst4()
end.isV6 = false
dst.Port = int(e.Port())
dst.Addr = e.Addr().As4()
end.ClearSrc()
return &end, nil
}
if e.Addr().Is6() {
zone, err := zoneToUint32(e.Addr().Zone())
if err != nil {
return nil, err
}
dst := end.dst6()
end.isV6 = true
dst.Port = int(e.Port())
dst.ZoneId = zone
dst.Addr = e.Addr().As16()
end.ClearSrc()
return &end, nil
}
return nil, errors.New("invalid IP address")
}
func (bind *LinuxSocketBind) Open(port uint16) ([]ReceiveFunc, uint16, error) {
bind.mu.Lock()
defer bind.mu.Unlock()
var err error
var newPort uint16
var tries int
if bind.sock4 != -1 || bind.sock6 != -1 {
return nil, 0, ErrBindAlreadyOpen
}
originalPort := port
again:
port = originalPort
var sock4, sock6 int
// Attempt ipv6 bind, update port if successful.
sock6, newPort, err = create6(port)
if err != nil {
if !errors.Is(err, syscall.EAFNOSUPPORT) {
return nil, 0, err
}
} else {
port = newPort
}
// Attempt ipv4 bind, update port if successful.
sock4, newPort, err = create4(port)
if err != nil {
if originalPort == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
unix.Close(sock6)
tries++
goto again
}
if !errors.Is(err, syscall.EAFNOSUPPORT) {
unix.Close(sock6)
return nil, 0, err
}
} else {
port = newPort
}
var fns []ReceiveFunc
if sock4 != -1 {
bind.sock4 = sock4
fns = append(fns, bind.receiveIPv4)
}
if sock6 != -1 {
bind.sock6 = sock6
fns = append(fns, bind.receiveIPv6)
}
if len(fns) == 0 {
return nil, 0, syscall.EAFNOSUPPORT
}
return fns, port, nil
}
func (bind *LinuxSocketBind) SetMark(value uint32) error {
bind.mu.RLock()
defer bind.mu.RUnlock()
if bind.sock6 != -1 {
err := unix.SetsockoptInt(
bind.sock6,
unix.SOL_SOCKET,
unix.SO_MARK,
int(value),
)
if err != nil {
return err
}
}
if bind.sock4 != -1 {
err := unix.SetsockoptInt(
bind.sock4,
unix.SOL_SOCKET,
unix.SO_MARK,
int(value),
)
if err != nil {
return err
}
}
return nil
}
func (bind *LinuxSocketBind) Close() error {
// Take a readlock to shut down the sockets...
bind.mu.RLock()
if bind.sock6 != -1 {
unix.Shutdown(bind.sock6, unix.SHUT_RDWR)
}
if bind.sock4 != -1 {
unix.Shutdown(bind.sock4, unix.SHUT_RDWR)
}
bind.mu.RUnlock()
// ...and a write lock to close the fd.
// This ensures that no one else is using the fd.
bind.mu.Lock()
defer bind.mu.Unlock()
var err1, err2 error
if bind.sock6 != -1 {
err1 = unix.Close(bind.sock6)
bind.sock6 = -1
}
if bind.sock4 != -1 {
err2 = unix.Close(bind.sock4)
bind.sock4 = -1
}
if err1 != nil {
return err1
}
return err2
}
func (bind *LinuxSocketBind) receiveIPv4(buf []byte) (int, Endpoint, error) {
bind.mu.RLock()
defer bind.mu.RUnlock()
if bind.sock4 == -1 {
return 0, nil, net.ErrClosed
}
var end LinuxSocketEndpoint
n, err := receive4(bind.sock4, buf, &end)
return n, &end, err
}
func (bind *LinuxSocketBind) receiveIPv6(buf []byte) (int, Endpoint, error) {
bind.mu.RLock()
defer bind.mu.RUnlock()
if bind.sock6 == -1 {
return 0, nil, net.ErrClosed
}
var end LinuxSocketEndpoint
n, err := receive6(bind.sock6, buf, &end)
return n, &end, err
}
func (bind *LinuxSocketBind) Send(buff []byte, end Endpoint) error {
nend, ok := end.(*LinuxSocketEndpoint)
if !ok {
return ErrWrongEndpointType
}
bind.mu.RLock()
defer bind.mu.RUnlock()
if !nend.isV6 {
if bind.sock4 == -1 {
return net.ErrClosed
}
return send4(bind.sock4, nend, buff)
} else {
if bind.sock6 == -1 {
return net.ErrClosed
}
return send6(bind.sock6, nend, buff)
}
}
func (end *LinuxSocketEndpoint) SrcIP() netip.Addr {
if !end.isV6 {
return netip.AddrFrom4(end.src4().Src)
} else {
return netip.AddrFrom16(end.src6().src)
}
}
func (end *LinuxSocketEndpoint) DstIP() netip.Addr {
if !end.isV6 {
return netip.AddrFrom4(end.dst4().Addr)
} else {
return netip.AddrFrom16(end.dst6().Addr)
}
}
func (end *LinuxSocketEndpoint) DstToBytes() []byte {
if !end.isV6 {
return (*[unsafe.Offsetof(end.dst4().Addr) + unsafe.Sizeof(end.dst4().Addr)]byte)(unsafe.Pointer(end.dst4()))[:]
} else {
return (*[unsafe.Offsetof(end.dst6().Addr) + unsafe.Sizeof(end.dst6().Addr)]byte)(unsafe.Pointer(end.dst6()))[:]
}
}
func (end *LinuxSocketEndpoint) SrcToString() string {
return end.SrcIP().String()
}
func (end *LinuxSocketEndpoint) DstToString() string {
var port int
if !end.isV6 {
port = end.dst4().Port
} else {
port = end.dst6().Port
}
return netip.AddrPortFrom(end.DstIP(), uint16(port)).String()
}
func (end *LinuxSocketEndpoint) ClearDst() {
for i := range end.dst {
end.dst[i] = 0
}
}
func (end *LinuxSocketEndpoint) ClearSrc() {
for i := range end.src {
end.src[i] = 0
}
}
func zoneToUint32(zone string) (uint32, error) {
if zone == "" {
return 0, nil
}
if intr, err := net.InterfaceByName(zone); err == nil {
return uint32(intr.Index), nil
}
n, err := strconv.ParseUint(zone, 10, 32)
return uint32(n), err
}
func create4(port uint16) (int, uint16, error) {
// create socket
fd, err := unix.Socket(
unix.AF_INET,
unix.SOCK_DGRAM,
0,
)
if err != nil {
return -1, 0, err
}
addr := unix.SockaddrInet4{
Port: int(port),
}
// set sockopts and bind
if err := func() error {
if err := unix.SetsockoptInt(
fd,
unix.IPPROTO_IP,
unix.IP_PKTINFO,
1,
); err != nil {
return err
}
return unix.Bind(fd, &addr)
}(); err != nil {
unix.Close(fd)
return -1, 0, err
}
sa, err := unix.Getsockname(fd)
if err == nil {
addr.Port = sa.(*unix.SockaddrInet4).Port
}
return fd, uint16(addr.Port), err
}
func create6(port uint16) (int, uint16, error) {
// create socket
fd, err := unix.Socket(
unix.AF_INET6,
unix.SOCK_DGRAM,
0,
)
if err != nil {
return -1, 0, err
}
// set sockopts and bind
addr := unix.SockaddrInet6{
Port: int(port),
}
if err := func() error {
if err := unix.SetsockoptInt(
fd,
unix.IPPROTO_IPV6,
unix.IPV6_RECVPKTINFO,
1,
); err != nil {
return err
}
if err := unix.SetsockoptInt(
fd,
unix.IPPROTO_IPV6,
unix.IPV6_V6ONLY,
1,
); err != nil {
return err
}
return unix.Bind(fd, &addr)
}(); err != nil {
unix.Close(fd)
return -1, 0, err
}
sa, err := unix.Getsockname(fd)
if err == nil {
addr.Port = sa.(*unix.SockaddrInet6).Port
}
return fd, uint16(addr.Port), err
}
func send4(sock int, end *LinuxSocketEndpoint, buff []byte) error {
// construct message header
cmsg := struct {
cmsghdr unix.Cmsghdr
pktinfo unix.Inet4Pktinfo
}{
unix.Cmsghdr{
Level: unix.IPPROTO_IP,
Type: unix.IP_PKTINFO,
Len: unix.SizeofInet4Pktinfo + unix.SizeofCmsghdr,
},
unix.Inet4Pktinfo{
Spec_dst: end.src4().Src,
Ifindex: end.src4().Ifindex,
},
}
end.mu.Lock()
_, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0)
end.mu.Unlock()
if err == nil {
return nil
}
// clear src and retry
if err == unix.EINVAL {
end.ClearSrc()
cmsg.pktinfo = unix.Inet4Pktinfo{}
end.mu.Lock()
_, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0)
end.mu.Unlock()
}
return err
}
func send6(sock int, end *LinuxSocketEndpoint, buff []byte) error {
// construct message header
cmsg := struct {
cmsghdr unix.Cmsghdr
pktinfo unix.Inet6Pktinfo
}{
unix.Cmsghdr{
Level: unix.IPPROTO_IPV6,
Type: unix.IPV6_PKTINFO,
Len: unix.SizeofInet6Pktinfo + unix.SizeofCmsghdr,
},
unix.Inet6Pktinfo{
Addr: end.src6().src,
Ifindex: end.dst6().ZoneId,
},
}
if cmsg.pktinfo.Addr == [16]byte{} {
cmsg.pktinfo.Ifindex = 0
}
end.mu.Lock()
_, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0)
end.mu.Unlock()
if err == nil {
return nil
}
// clear src and retry
if err == unix.EINVAL {
end.ClearSrc()
cmsg.pktinfo = unix.Inet6Pktinfo{}
end.mu.Lock()
_, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0)
end.mu.Unlock()
}
return err
}
func receive4(sock int, buff []byte, end *LinuxSocketEndpoint) (int, error) {
// construct message header
var cmsg struct {
cmsghdr unix.Cmsghdr
pktinfo unix.Inet4Pktinfo
}
size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0)
if err != nil {
return 0, err
}
end.isV6 = false
if newDst4, ok := newDst.(*unix.SockaddrInet4); ok {
*end.dst4() = *newDst4
}
// update source cache
if cmsg.cmsghdr.Level == unix.IPPROTO_IP &&
cmsg.cmsghdr.Type == unix.IP_PKTINFO &&
cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo {
end.src4().Src = cmsg.pktinfo.Spec_dst
end.src4().Ifindex = cmsg.pktinfo.Ifindex
}
return size, nil
}
func receive6(sock int, buff []byte, end *LinuxSocketEndpoint) (int, error) {
// construct message header
var cmsg struct {
cmsghdr unix.Cmsghdr
pktinfo unix.Inet6Pktinfo
}
size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0)
if err != nil {
return 0, err
}
end.isV6 = true
if newDst6, ok := newDst.(*unix.SockaddrInet6); ok {
*end.dst6() = *newDst6
}
// update source cache
if cmsg.cmsghdr.Level == unix.IPPROTO_IPV6 &&
cmsg.cmsghdr.Type == unix.IPV6_PKTINFO &&
cmsg.cmsghdr.Len >= unix.SizeofInet6Pktinfo {
end.src6().src = cmsg.pktinfo.Addr
end.dst6().ZoneId = cmsg.pktinfo.Ifindex
}
return size, nil
}

View file

@ -1,126 +1,80 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package conn
import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"runtime"
"strconv"
"sync"
"syscall"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
)
var (
_ Bind = (*StdNetBind)(nil)
)
// StdNetBind implements Bind for all platforms. While Windows has its own Bind
// (see bind_windows.go), it may fall back to StdNetBind.
// TODO: Remove usage of ipv{4,6}.PacketConn when net.UDPConn has comparable
// methods for sending and receiving multiple datagrams per-syscall. See the
// proposal in https://github.com/golang/go/issues/45886#issuecomment-1218301564.
// StdNetBind is meant to be a temporary solution on platforms for which
// the sticky socket / source caching behavior has not yet been implemented.
// It uses the Go's net package to implement networking.
// See LinuxSocketBind for a proper implementation on the Linux platform.
type StdNetBind struct {
mu sync.Mutex // protects all fields except as specified
ipv4 *net.UDPConn
ipv6 *net.UDPConn
ipv4PC *ipv4.PacketConn // will be nil on non-Linux
ipv6PC *ipv6.PacketConn // will be nil on non-Linux
ipv4TxOffload bool
ipv4RxOffload bool
ipv6TxOffload bool
ipv6RxOffload bool
// these two fields are not guarded by mu
udpAddrPool sync.Pool
msgsPool sync.Pool
mu sync.Mutex // protects following fields
ipv4 *net.UDPConn
ipv6 *net.UDPConn
blackhole4 bool
blackhole6 bool
}
func NewStdNetBind() Bind {
return &StdNetBind{
udpAddrPool: sync.Pool{
New: func() any {
return &net.UDPAddr{
IP: make([]byte, 16),
}
},
},
func NewStdNetBind() Bind { return &StdNetBind{} }
msgsPool: sync.Pool{
New: func() any {
// ipv6.Message and ipv4.Message are interchangeable as they are
// both aliases for x/net/internal/socket.Message.
msgs := make([]ipv6.Message, IdealBatchSize)
for i := range msgs {
msgs[i].Buffers = make(net.Buffers, 1)
msgs[i].OOB = make([]byte, 0, stickyControlSize+gsoControlSize)
}
return &msgs
},
},
}
}
type StdNetEndpoint struct {
// AddrPort is the endpoint destination.
netip.AddrPort
// src is the current sticky source address and interface index, if
// supported. Typically this is a PKTINFO structure from/for control
// messages, see unix.PKTINFO for an example.
src []byte
}
type StdNetEndpoint net.UDPAddr
var (
_ Bind = (*StdNetBind)(nil)
_ Endpoint = &StdNetEndpoint{}
_ Endpoint = (*StdNetEndpoint)(nil)
)
func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) {
e, err := netip.ParseAddrPort(s)
if err != nil {
return nil, err
}
return &StdNetEndpoint{
AddrPort: e,
}, nil
return (*StdNetEndpoint)(&net.UDPAddr{
IP: e.Addr().AsSlice(),
Port: int(e.Port()),
Zone: e.Addr().Zone(),
}), err
}
func (e *StdNetEndpoint) ClearSrc() {
if e.src != nil {
// Truncate src, no need to reallocate.
e.src = e.src[:0]
}
}
func (*StdNetEndpoint) ClearSrc() {}
func (e *StdNetEndpoint) DstIP() netip.Addr {
return e.AddrPort.Addr()
a, _ := netip.AddrFromSlice((*net.UDPAddr)(e).IP)
return a
}
// See control_default,linux, etc for implementations of SrcIP and SrcIfidx.
func (e *StdNetEndpoint) SrcIP() netip.Addr {
return netip.Addr{} // not supported
}
func (e *StdNetEndpoint) DstToBytes() []byte {
b, _ := e.AddrPort.MarshalBinary()
return b
addr := (*net.UDPAddr)(e)
out := addr.IP.To4()
if out == nil {
out = addr.IP
}
out = append(out, byte(addr.Port&0xff))
out = append(out, byte((addr.Port>>8)&0xff))
return out
}
func (e *StdNetEndpoint) DstToString() string {
return e.AddrPort.String()
return (*net.UDPAddr)(e).String()
}
func (e *StdNetEndpoint) SrcToString() string {
return ""
}
func listenNet(network string, port int) (*net.UDPConn, int, error) {
conn, err := listenConfig().ListenPacket(context.Background(), network, ":"+strconv.Itoa(port))
conn, err := net.ListenUDP(network, &net.UDPAddr{Port: port})
if err != nil {
return nil, 0, err
}
@ -134,17 +88,17 @@ func listenNet(network string, port int) (*net.UDPConn, int, error) {
if err != nil {
return nil, 0, err
}
return conn.(*net.UDPConn), uaddr.Port, nil
return conn, uaddr.Port, nil
}
func (s *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
s.mu.Lock()
defer s.mu.Unlock()
func (bind *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
bind.mu.Lock()
defer bind.mu.Unlock()
var err error
var tries int
if s.ipv4 != nil || s.ipv6 != nil {
if bind.ipv4 != nil || bind.ipv6 != nil {
return nil, 0, ErrBindAlreadyOpen
}
@ -152,207 +106,92 @@ func (s *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
// If uport is 0, we can retry on failure.
again:
port := int(uport)
var v4conn, v6conn *net.UDPConn
var v4pc *ipv4.PacketConn
var v6pc *ipv6.PacketConn
var ipv4, ipv6 *net.UDPConn
v4conn, port, err = listenNet("udp4", port)
ipv4, port, err = listenNet("udp4", port)
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
return nil, 0, err
}
// Listen on the same port as we're using for ipv4.
v6conn, port, err = listenNet("udp6", port)
ipv6, port, err = listenNet("udp6", port)
if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
v4conn.Close()
ipv4.Close()
tries++
goto again
}
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
v4conn.Close()
ipv4.Close()
return nil, 0, err
}
var fns []ReceiveFunc
if v4conn != nil {
s.ipv4TxOffload, s.ipv4RxOffload = supportsUDPOffload(v4conn)
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
v4pc = ipv4.NewPacketConn(v4conn)
s.ipv4PC = v4pc
}
fns = append(fns, s.makeReceiveIPv4(v4pc, v4conn, s.ipv4RxOffload))
s.ipv4 = v4conn
if ipv4 != nil {
fns = append(fns, bind.makeReceiveIPv4(ipv4))
bind.ipv4 = ipv4
}
if v6conn != nil {
s.ipv6TxOffload, s.ipv6RxOffload = supportsUDPOffload(v6conn)
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
v6pc = ipv6.NewPacketConn(v6conn)
s.ipv6PC = v6pc
}
fns = append(fns, s.makeReceiveIPv6(v6pc, v6conn, s.ipv6RxOffload))
s.ipv6 = v6conn
if ipv6 != nil {
fns = append(fns, bind.makeReceiveIPv6(ipv6))
bind.ipv6 = ipv6
}
if len(fns) == 0 {
return nil, 0, syscall.EAFNOSUPPORT
}
return fns, uint16(port), nil
}
func (s *StdNetBind) putMessages(msgs *[]ipv6.Message) {
for i := range *msgs {
(*msgs)[i].OOB = (*msgs)[i].OOB[:0]
(*msgs)[i] = ipv6.Message{Buffers: (*msgs)[i].Buffers, OOB: (*msgs)[i].OOB}
}
s.msgsPool.Put(msgs)
}
func (s *StdNetBind) getMessages() *[]ipv6.Message {
return s.msgsPool.Get().(*[]ipv6.Message)
}
var (
// If compilation fails here these are no longer the same underlying type.
_ ipv6.Message = ipv4.Message{}
)
type batchReader interface {
ReadBatch([]ipv6.Message, int) (int, error)
}
type batchWriter interface {
WriteBatch([]ipv6.Message, int) (int, error)
}
func (s *StdNetBind) receiveIP(
br batchReader,
conn *net.UDPConn,
rxOffload bool,
bufs [][]byte,
sizes []int,
eps []Endpoint,
) (n int, err error) {
msgs := s.getMessages()
for i := range bufs {
(*msgs)[i].Buffers[0] = bufs[i]
(*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)]
}
defer s.putMessages(msgs)
var numMsgs int
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
if rxOffload {
readAt := len(*msgs) - (IdealBatchSize / udpSegmentMaxDatagrams)
numMsgs, err = br.ReadBatch((*msgs)[readAt:], 0)
if err != nil {
return 0, err
}
numMsgs, err = splitCoalescedMessages(*msgs, readAt, getGSOSize)
if err != nil {
return 0, err
}
} else {
numMsgs, err = br.ReadBatch(*msgs, 0)
if err != nil {
return 0, err
}
}
} else {
msg := &(*msgs)[0]
msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
if err != nil {
return 0, err
}
numMsgs = 1
}
for i := 0; i < numMsgs; i++ {
msg := &(*msgs)[i]
sizes[i] = msg.N
if sizes[i] == 0 {
continue
}
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
getSrcFromControl(msg.OOB[:msg.NN], ep)
eps[i] = ep
}
return numMsgs, nil
}
func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc {
return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps)
}
}
func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc {
return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps)
}
}
// TODO: When all Binds handle IdealBatchSize, remove this dynamic function and
// rename the IdealBatchSize constant to BatchSize.
func (s *StdNetBind) BatchSize() int {
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
return IdealBatchSize
}
return 1
}
func (s *StdNetBind) Close() error {
s.mu.Lock()
defer s.mu.Unlock()
func (bind *StdNetBind) Close() error {
bind.mu.Lock()
defer bind.mu.Unlock()
var err1, err2 error
if s.ipv4 != nil {
err1 = s.ipv4.Close()
s.ipv4 = nil
s.ipv4PC = nil
if bind.ipv4 != nil {
err1 = bind.ipv4.Close()
bind.ipv4 = nil
}
if s.ipv6 != nil {
err2 = s.ipv6.Close()
s.ipv6 = nil
s.ipv6PC = nil
if bind.ipv6 != nil {
err2 = bind.ipv6.Close()
bind.ipv6 = nil
}
s.blackhole4 = false
s.blackhole6 = false
s.ipv4TxOffload = false
s.ipv4RxOffload = false
s.ipv6TxOffload = false
s.ipv6RxOffload = false
bind.blackhole4 = false
bind.blackhole6 = false
if err1 != nil {
return err1
}
return err2
}
type ErrUDPGSODisabled struct {
onLaddr string
RetryErr error
}
func (e ErrUDPGSODisabled) Error() string {
return fmt.Sprintf("disabled UDP GSO on %s, NIC(s) may not support checksum offload or peer MTU with protocol headers is greater than path MTU", e.onLaddr)
}
func (e ErrUDPGSODisabled) Unwrap() error {
return e.RetryErr
}
func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint) error {
s.mu.Lock()
blackhole := s.blackhole4
conn := s.ipv4
offload := s.ipv4TxOffload
br := batchWriter(s.ipv4PC)
is6 := false
if endpoint.DstIP().Is6() {
blackhole = s.blackhole6
conn = s.ipv6
br = s.ipv6PC
is6 = true
offload = s.ipv6TxOffload
func (*StdNetBind) makeReceiveIPv4(conn *net.UDPConn) ReceiveFunc {
return func(buff []byte) (int, Endpoint, error) {
n, endpoint, err := conn.ReadFromUDP(buff)
if endpoint != nil {
endpoint.IP = endpoint.IP.To4()
}
return n, (*StdNetEndpoint)(endpoint), err
}
s.mu.Unlock()
}
func (*StdNetBind) makeReceiveIPv6(conn *net.UDPConn) ReceiveFunc {
return func(buff []byte) (int, Endpoint, error) {
n, endpoint, err := conn.ReadFromUDP(buff)
return n, (*StdNetEndpoint)(endpoint), err
}
}
func (bind *StdNetBind) Send(buff []byte, endpoint Endpoint) error {
var err error
nend, ok := endpoint.(*StdNetEndpoint)
if !ok {
return ErrWrongEndpointType
}
bind.mu.Lock()
blackhole := bind.blackhole4
conn := bind.ipv4
if nend.IP.To4() == nil {
blackhole = bind.blackhole6
conn = bind.ipv6
}
bind.mu.Unlock()
if blackhole {
return nil
@ -360,185 +199,6 @@ func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint) error {
if conn == nil {
return syscall.EAFNOSUPPORT
}
msgs := s.getMessages()
defer s.putMessages(msgs)
ua := s.udpAddrPool.Get().(*net.UDPAddr)
defer s.udpAddrPool.Put(ua)
if is6 {
as16 := endpoint.DstIP().As16()
copy(ua.IP, as16[:])
ua.IP = ua.IP[:16]
} else {
as4 := endpoint.DstIP().As4()
copy(ua.IP, as4[:])
ua.IP = ua.IP[:4]
}
ua.Port = int(endpoint.(*StdNetEndpoint).Port())
var (
retried bool
err error
)
retry:
if offload {
n := coalesceMessages(ua, endpoint.(*StdNetEndpoint), bufs, *msgs, setGSOSize)
err = s.send(conn, br, (*msgs)[:n])
if err != nil && offload && errShouldDisableUDPGSO(err) {
offload = false
s.mu.Lock()
if is6 {
s.ipv6TxOffload = false
} else {
s.ipv4TxOffload = false
}
s.mu.Unlock()
retried = true
goto retry
}
} else {
for i := range bufs {
(*msgs)[i].Addr = ua
(*msgs)[i].Buffers[0] = bufs[i]
setSrcControl(&(*msgs)[i].OOB, endpoint.(*StdNetEndpoint))
}
err = s.send(conn, br, (*msgs)[:len(bufs)])
}
if retried {
return ErrUDPGSODisabled{onLaddr: conn.LocalAddr().String(), RetryErr: err}
}
_, err = conn.WriteToUDP(buff, (*net.UDPAddr)(nend))
return err
}
func (s *StdNetBind) send(conn *net.UDPConn, pc batchWriter, msgs []ipv6.Message) error {
var (
n int
err error
start int
)
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
for {
n, err = pc.WriteBatch(msgs[start:], 0)
if err != nil || n == len(msgs[start:]) {
break
}
start += n
}
} else {
for _, msg := range msgs {
_, _, err = conn.WriteMsgUDP(msg.Buffers[0], msg.OOB, msg.Addr.(*net.UDPAddr))
if err != nil {
break
}
}
}
return err
}
const (
// Exceeding these values results in EMSGSIZE. They account for layer3 and
// layer4 headers. IPv6 does not need to account for itself as the payload
// length field is self excluding.
maxIPv4PayloadLen = 1<<16 - 1 - 20 - 8
maxIPv6PayloadLen = 1<<16 - 1 - 8
// This is a hard limit imposed by the kernel.
udpSegmentMaxDatagrams = 64
)
type setGSOFunc func(control *[]byte, gsoSize uint16)
func coalesceMessages(addr *net.UDPAddr, ep *StdNetEndpoint, bufs [][]byte, msgs []ipv6.Message, setGSO setGSOFunc) int {
var (
base = -1 // index of msg we are currently coalescing into
gsoSize int // segmentation size of msgs[base]
dgramCnt int // number of dgrams coalesced into msgs[base]
endBatch bool // tracking flag to start a new batch on next iteration of bufs
)
maxPayloadLen := maxIPv4PayloadLen
if ep.DstIP().Is6() {
maxPayloadLen = maxIPv6PayloadLen
}
for i, buf := range bufs {
if i > 0 {
msgLen := len(buf)
baseLenBefore := len(msgs[base].Buffers[0])
freeBaseCap := cap(msgs[base].Buffers[0]) - baseLenBefore
if msgLen+baseLenBefore <= maxPayloadLen &&
msgLen <= gsoSize &&
msgLen <= freeBaseCap &&
dgramCnt < udpSegmentMaxDatagrams &&
!endBatch {
msgs[base].Buffers[0] = append(msgs[base].Buffers[0], buf...)
if i == len(bufs)-1 {
setGSO(&msgs[base].OOB, uint16(gsoSize))
}
dgramCnt++
if msgLen < gsoSize {
// A smaller than gsoSize packet on the tail is legal, but
// it must end the batch.
endBatch = true
}
continue
}
}
if dgramCnt > 1 {
setGSO(&msgs[base].OOB, uint16(gsoSize))
}
// Reset prior to incrementing base since we are preparing to start a
// new potential batch.
endBatch = false
base++
gsoSize = len(buf)
setSrcControl(&msgs[base].OOB, ep)
msgs[base].Buffers[0] = buf
msgs[base].Addr = addr
dgramCnt = 1
}
return base + 1
}
type getGSOFunc func(control []byte) (int, error)
func splitCoalescedMessages(msgs []ipv6.Message, firstMsgAt int, getGSO getGSOFunc) (n int, err error) {
for i := firstMsgAt; i < len(msgs); i++ {
msg := &msgs[i]
if msg.N == 0 {
return n, err
}
var (
gsoSize int
start int
end = msg.N
numToSplit = 1
)
gsoSize, err = getGSO(msg.OOB[:msg.NN])
if err != nil {
return n, err
}
if gsoSize > 0 {
numToSplit = (msg.N + gsoSize - 1) / gsoSize
end = gsoSize
}
for j := 0; j < numToSplit; j++ {
if n > i {
return n, errors.New("splitting coalesced packet resulted in overflow")
}
copied := copy(msgs[n].Buffers[0], msg.Buffers[0][start:end])
msgs[n].N = copied
msgs[n].Addr = msg.Addr
start = end
end += gsoSize
if end > msg.N {
end = msg.N
}
n++
}
if i != n-1 {
// It is legal for bytes to move within msg.Buffers[0] as a result
// of splitting, so we only zero the source msg len when it is not
// the destination of the last split operation above.
msg.N = 0
}
}
return n, nil
}

View file

@ -1,250 +0,0 @@
package conn
import (
"encoding/binary"
"net"
"testing"
"golang.org/x/net/ipv6"
)
func TestStdNetBindReceiveFuncAfterClose(t *testing.T) {
bind := NewStdNetBind().(*StdNetBind)
fns, _, err := bind.Open(0)
if err != nil {
t.Fatal(err)
}
bind.Close()
bufs := make([][]byte, 1)
bufs[0] = make([]byte, 1)
sizes := make([]int, 1)
eps := make([]Endpoint, 1)
for _, fn := range fns {
// The ReceiveFuncs must not access conn-related fields on StdNetBind
// unguarded. Close() nils the conn-related fields resulting in a panic
// if they violate the mutex.
fn(bufs, sizes, eps)
}
}
func mockSetGSOSize(control *[]byte, gsoSize uint16) {
*control = (*control)[:cap(*control)]
binary.LittleEndian.PutUint16(*control, gsoSize)
}
func Test_coalesceMessages(t *testing.T) {
cases := []struct {
name string
buffs [][]byte
wantLens []int
wantGSO []int
}{
{
name: "one message no coalesce",
buffs: [][]byte{
make([]byte, 1, 1),
},
wantLens: []int{1},
wantGSO: []int{0},
},
{
name: "two messages equal len coalesce",
buffs: [][]byte{
make([]byte, 1, 2),
make([]byte, 1, 1),
},
wantLens: []int{2},
wantGSO: []int{1},
},
{
name: "two messages unequal len coalesce",
buffs: [][]byte{
make([]byte, 2, 3),
make([]byte, 1, 1),
},
wantLens: []int{3},
wantGSO: []int{2},
},
{
name: "three messages second unequal len coalesce",
buffs: [][]byte{
make([]byte, 2, 3),
make([]byte, 1, 1),
make([]byte, 2, 2),
},
wantLens: []int{3, 2},
wantGSO: []int{2, 0},
},
{
name: "three messages limited cap coalesce",
buffs: [][]byte{
make([]byte, 2, 4),
make([]byte, 2, 2),
make([]byte, 2, 2),
},
wantLens: []int{4, 2},
wantGSO: []int{2, 0},
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
addr := &net.UDPAddr{
IP: net.ParseIP("127.0.0.1").To4(),
Port: 1,
}
msgs := make([]ipv6.Message, len(tt.buffs))
for i := range msgs {
msgs[i].Buffers = make([][]byte, 1)
msgs[i].OOB = make([]byte, 0, 2)
}
got := coalesceMessages(addr, &StdNetEndpoint{AddrPort: addr.AddrPort()}, tt.buffs, msgs, mockSetGSOSize)
if got != len(tt.wantLens) {
t.Fatalf("got len %d want: %d", got, len(tt.wantLens))
}
for i := 0; i < got; i++ {
if msgs[i].Addr != addr {
t.Errorf("msgs[%d].Addr != passed addr", i)
}
gotLen := len(msgs[i].Buffers[0])
if gotLen != tt.wantLens[i] {
t.Errorf("len(msgs[%d].Buffers[0]) %d != %d", i, gotLen, tt.wantLens[i])
}
gotGSO, err := mockGetGSOSize(msgs[i].OOB)
if err != nil {
t.Fatalf("msgs[%d] getGSOSize err: %v", i, err)
}
if gotGSO != tt.wantGSO[i] {
t.Errorf("msgs[%d] gsoSize %d != %d", i, gotGSO, tt.wantGSO[i])
}
}
})
}
}
func mockGetGSOSize(control []byte) (int, error) {
if len(control) < 2 {
return 0, nil
}
return int(binary.LittleEndian.Uint16(control)), nil
}
func Test_splitCoalescedMessages(t *testing.T) {
newMsg := func(n, gso int) ipv6.Message {
msg := ipv6.Message{
Buffers: [][]byte{make([]byte, 1<<16-1)},
N: n,
OOB: make([]byte, 2),
}
binary.LittleEndian.PutUint16(msg.OOB, uint16(gso))
if gso > 0 {
msg.NN = 2
}
return msg
}
cases := []struct {
name string
msgs []ipv6.Message
firstMsgAt int
wantNumEval int
wantMsgLens []int
wantErr bool
}{
{
name: "second last split last empty",
msgs: []ipv6.Message{
newMsg(0, 0),
newMsg(0, 0),
newMsg(3, 1),
newMsg(0, 0),
},
firstMsgAt: 2,
wantNumEval: 3,
wantMsgLens: []int{1, 1, 1, 0},
wantErr: false,
},
{
name: "second last no split last empty",
msgs: []ipv6.Message{
newMsg(0, 0),
newMsg(0, 0),
newMsg(1, 0),
newMsg(0, 0),
},
firstMsgAt: 2,
wantNumEval: 1,
wantMsgLens: []int{1, 0, 0, 0},
wantErr: false,
},
{
name: "second last no split last no split",
msgs: []ipv6.Message{
newMsg(0, 0),
newMsg(0, 0),
newMsg(1, 0),
newMsg(1, 0),
},
firstMsgAt: 2,
wantNumEval: 2,
wantMsgLens: []int{1, 1, 0, 0},
wantErr: false,
},
{
name: "second last no split last split",
msgs: []ipv6.Message{
newMsg(0, 0),
newMsg(0, 0),
newMsg(1, 0),
newMsg(3, 1),
},
firstMsgAt: 2,
wantNumEval: 4,
wantMsgLens: []int{1, 1, 1, 1},
wantErr: false,
},
{
name: "second last split last split",
msgs: []ipv6.Message{
newMsg(0, 0),
newMsg(0, 0),
newMsg(2, 1),
newMsg(2, 1),
},
firstMsgAt: 2,
wantNumEval: 4,
wantMsgLens: []int{1, 1, 1, 1},
wantErr: false,
},
{
name: "second last no split last split overflow",
msgs: []ipv6.Message{
newMsg(0, 0),
newMsg(0, 0),
newMsg(1, 0),
newMsg(4, 1),
},
firstMsgAt: 2,
wantNumEval: 4,
wantMsgLens: []int{1, 1, 1, 1},
wantErr: true,
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
got, err := splitCoalescedMessages(tt.msgs, 2, mockGetGSOSize)
if err != nil && !tt.wantErr {
t.Fatalf("err: %v", err)
}
if got != tt.wantNumEval {
t.Fatalf("got to eval: %d want: %d", got, tt.wantNumEval)
}
for i, msg := range tt.msgs {
if msg.N != tt.wantMsgLens[i] {
t.Fatalf("msg[%d].N: %d want: %d", i, msg.N, tt.wantMsgLens[i])
}
}
})
}
}

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package conn
@ -17,7 +17,7 @@ import (
"golang.org/x/sys/windows"
"github.com/amnezia-vpn/amneziawg-go/conn/winrio"
"golang.zx2c4.com/wireguard/conn/winrio"
)
const (
@ -74,7 +74,7 @@ type afWinRingBind struct {
type WinRingBind struct {
v4, v6 afWinRingBind
mu sync.RWMutex
isOpen atomic.Uint32 // 0, 1, or 2
isOpen uint32
}
func NewDefaultBind() Bind { return NewWinRingBind() }
@ -164,7 +164,7 @@ func (e *WinRingEndpoint) DstToBytes() []byte {
func (e *WinRingEndpoint) DstToString() string {
switch e.family {
case windows.AF_INET:
return netip.AddrPortFrom(netip.AddrFrom4(*(*[4]byte)(e.data[2:6])), binary.BigEndian.Uint16(e.data[0:2])).String()
netip.AddrPortFrom(netip.AddrFrom4(*(*[4]byte)(e.data[2:6])), binary.BigEndian.Uint16(e.data[0:2])).String()
case windows.AF_INET6:
var zone string
if scope := *(*uint32)(unsafe.Pointer(&e.data[22])); scope > 0 {
@ -212,7 +212,7 @@ func (bind *afWinRingBind) CloseAndZero() {
}
func (bind *WinRingBind) closeAndZero() {
bind.isOpen.Store(0)
atomic.StoreUint32(&bind.isOpen, 0)
bind.v4.CloseAndZero()
bind.v6.CloseAndZero()
}
@ -276,7 +276,7 @@ func (bind *WinRingBind) Open(port uint16) (recvFns []ReceiveFunc, selectedPort
bind.closeAndZero()
}
}()
if bind.isOpen.Load() != 0 {
if atomic.LoadUint32(&bind.isOpen) != 0 {
return nil, 0, ErrBindAlreadyOpen
}
var sa windows.Sockaddr
@ -299,17 +299,17 @@ func (bind *WinRingBind) Open(port uint16) (recvFns []ReceiveFunc, selectedPort
return nil, 0, err
}
}
bind.isOpen.Store(1)
atomic.StoreUint32(&bind.isOpen, 1)
return []ReceiveFunc{bind.receiveIPv4, bind.receiveIPv6}, selectedPort, err
}
func (bind *WinRingBind) Close() error {
bind.mu.RLock()
if bind.isOpen.Load() != 1 {
if atomic.LoadUint32(&bind.isOpen) != 1 {
bind.mu.RUnlock()
return nil
}
bind.isOpen.Store(2)
atomic.StoreUint32(&bind.isOpen, 2)
windows.PostQueuedCompletionStatus(bind.v4.rx.iocp, 0, 0, nil)
windows.PostQueuedCompletionStatus(bind.v4.tx.iocp, 0, 0, nil)
windows.PostQueuedCompletionStatus(bind.v6.rx.iocp, 0, 0, nil)
@ -321,13 +321,6 @@ func (bind *WinRingBind) Close() error {
return nil
}
// TODO: When all Binds handle IdealBatchSize, remove this dynamic function and
// rename the IdealBatchSize constant to BatchSize.
func (bind *WinRingBind) BatchSize() int {
// TODO: implement batching in and out of the ring
return 1
}
func (bind *WinRingBind) SetMark(mark uint32) error {
return nil
}
@ -352,8 +345,8 @@ func (bind *afWinRingBind) InsertReceiveRequest() error {
//go:linkname procyield runtime.procyield
func procyield(cycles uint32)
func (bind *afWinRingBind) Receive(buf []byte, isOpen *atomic.Uint32) (int, Endpoint, error) {
if isOpen.Load() != 1 {
func (bind *afWinRingBind) Receive(buf []byte, isOpen *uint32) (int, Endpoint, error) {
if atomic.LoadUint32(isOpen) != 1 {
return 0, nil, net.ErrClosed
}
bind.rx.mu.Lock()
@ -366,7 +359,7 @@ retry:
count = 0
for tries := 0; count == 0 && tries < receiveSpins; tries++ {
if tries > 0 {
if isOpen.Load() != 1 {
if atomic.LoadUint32(isOpen) != 1 {
return 0, nil, net.ErrClosed
}
procyield(1)
@ -385,7 +378,7 @@ retry:
if err != nil {
return 0, nil, err
}
if isOpen.Load() != 1 {
if atomic.LoadUint32(isOpen) != 1 {
return 0, nil, net.ErrClosed
}
count = winrio.DequeueCompletion(bind.rx.cq, results[:])
@ -402,7 +395,7 @@ retry:
// huge packets. Just try again when this happens. The infinite loop this could cause is still limited to
// attacker bandwidth, just like the rest of the receive path.
if windows.Errno(results[0].Status) == windows.WSAEMSGSIZE {
if isOpen.Load() != 1 {
if atomic.LoadUint32(isOpen) != 1 {
return 0, nil, net.ErrClosed
}
goto retry
@ -416,26 +409,20 @@ retry:
return n, &ep, nil
}
func (bind *WinRingBind) receiveIPv4(bufs [][]byte, sizes []int, eps []Endpoint) (int, error) {
func (bind *WinRingBind) receiveIPv4(buf []byte) (int, Endpoint, error) {
bind.mu.RLock()
defer bind.mu.RUnlock()
n, ep, err := bind.v4.Receive(bufs[0], &bind.isOpen)
sizes[0] = n
eps[0] = ep
return 1, err
return bind.v4.Receive(buf, &bind.isOpen)
}
func (bind *WinRingBind) receiveIPv6(bufs [][]byte, sizes []int, eps []Endpoint) (int, error) {
func (bind *WinRingBind) receiveIPv6(buf []byte) (int, Endpoint, error) {
bind.mu.RLock()
defer bind.mu.RUnlock()
n, ep, err := bind.v6.Receive(bufs[0], &bind.isOpen)
sizes[0] = n
eps[0] = ep
return 1, err
return bind.v6.Receive(buf, &bind.isOpen)
}
func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *atomic.Uint32) error {
if isOpen.Load() != 1 {
func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *uint32) error {
if atomic.LoadUint32(isOpen) != 1 {
return net.ErrClosed
}
if len(buf) > bytesPerPacket {
@ -457,7 +444,7 @@ func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *atomi
if err != nil {
return err
}
if isOpen.Load() != 1 {
if atomic.LoadUint32(isOpen) != 1 {
return net.ErrClosed
}
count = winrio.DequeueCompletion(bind.tx.cq, results[:])
@ -486,38 +473,32 @@ func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *atomi
return winrio.SendEx(bind.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0)
}
func (bind *WinRingBind) Send(bufs [][]byte, endpoint Endpoint) error {
func (bind *WinRingBind) Send(buf []byte, endpoint Endpoint) error {
nend, ok := endpoint.(*WinRingEndpoint)
if !ok {
return ErrWrongEndpointType
}
bind.mu.RLock()
defer bind.mu.RUnlock()
for _, buf := range bufs {
switch nend.family {
case windows.AF_INET:
if bind.v4.blackhole {
continue
}
if err := bind.v4.Send(buf, nend, &bind.isOpen); err != nil {
return err
}
case windows.AF_INET6:
if bind.v6.blackhole {
continue
}
if err := bind.v6.Send(buf, nend, &bind.isOpen); err != nil {
return err
}
switch nend.family {
case windows.AF_INET:
if bind.v4.blackhole {
return nil
}
return bind.v4.Send(buf, nend, &bind.isOpen)
case windows.AF_INET6:
if bind.v6.blackhole {
return nil
}
return bind.v6.Send(buf, nend, &bind.isOpen)
}
return nil
}
func (s *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
s.mu.Lock()
defer s.mu.Unlock()
sysconn, err := s.ipv4.SyscallConn()
func (bind *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
bind.mu.Lock()
defer bind.mu.Unlock()
sysconn, err := bind.ipv4.SyscallConn()
if err != nil {
return err
}
@ -530,14 +511,14 @@ func (s *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole boo
if err != nil {
return err
}
s.blackhole4 = blackhole
bind.blackhole4 = blackhole
return nil
}
func (s *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
s.mu.Lock()
defer s.mu.Unlock()
sysconn, err := s.ipv6.SyscallConn()
func (bind *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
bind.mu.Lock()
defer bind.mu.Unlock()
sysconn, err := bind.ipv6.SyscallConn()
if err != nil {
return err
}
@ -550,14 +531,14 @@ func (s *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole boo
if err != nil {
return err
}
s.blackhole6 = blackhole
bind.blackhole6 = blackhole
return nil
}
func (bind *WinRingBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
bind.mu.RLock()
defer bind.mu.RUnlock()
if bind.isOpen.Load() != 1 {
if atomic.LoadUint32(&bind.isOpen) != 1 {
return net.ErrClosed
}
err := bindSocketToInterface4(bind.v4.sock, interfaceIndex)
@ -571,7 +552,7 @@ func (bind *WinRingBind) BindSocketToInterface4(interfaceIndex uint32, blackhole
func (bind *WinRingBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
bind.mu.RLock()
defer bind.mu.RUnlock()
if bind.isOpen.Load() != 1 {
if atomic.LoadUint32(&bind.isOpen) != 1 {
return net.ErrClosed
}
err := bindSocketToInterface6(bind.v6.sock, interfaceIndex)

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
*/
package bindtest
@ -12,7 +12,7 @@ import (
"net/netip"
"os"
"github.com/amnezia-vpn/amneziawg-go/conn"
"golang.zx2c4.com/wireguard/conn"
)
type ChannelBind struct {
@ -89,39 +89,32 @@ func (c *ChannelBind) Close() error {
return nil
}
func (c *ChannelBind) BatchSize() int { return 1 }
func (c *ChannelBind) SetMark(mark uint32) error { return nil }
func (c *ChannelBind) makeReceiveFunc(ch chan []byte) conn.ReceiveFunc {
return func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) {
return func(b []byte) (n int, ep conn.Endpoint, err error) {
select {
case <-c.closeSignal:
return 0, net.ErrClosed
return 0, nil, net.ErrClosed
case rx := <-ch:
copied := copy(bufs[0], rx)
sizes[0] = copied
eps[0] = c.target6
return 1, nil
return copy(b, rx), c.target6, nil
}
}
}
func (c *ChannelBind) Send(bufs [][]byte, ep conn.Endpoint) error {
for _, b := range bufs {
select {
case <-c.closeSignal:
return net.ErrClosed
default:
bc := make([]byte, len(b))
copy(bc, b)
if ep.(ChannelEndpoint) == c.target4 {
*c.tx4 <- bc
} else if ep.(ChannelEndpoint) == c.target6 {
*c.tx6 <- bc
} else {
return os.ErrInvalid
}
func (c *ChannelBind) Send(b []byte, ep conn.Endpoint) error {
select {
case <-c.closeSignal:
return net.ErrClosed
default:
bc := make([]byte, len(b))
copy(bc, b)
if ep.(ChannelEndpoint) == c.target4 {
*c.tx4 <- bc
} else if ep.(ChannelEndpoint) == c.target6 {
*c.tx6 <- bc
} else {
return os.ErrInvalid
}
}
return nil

View file

@ -1,12 +1,12 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package conn
func (s *StdNetBind) PeekLookAtSocketFd4() (fd int, err error) {
sysconn, err := s.ipv4.SyscallConn()
func (bind *StdNetBind) PeekLookAtSocketFd4() (fd int, err error) {
sysconn, err := bind.ipv4.SyscallConn()
if err != nil {
return -1, err
}
@ -19,8 +19,8 @@ func (s *StdNetBind) PeekLookAtSocketFd4() (fd int, err error) {
return
}
func (s *StdNetBind) PeekLookAtSocketFd6() (fd int, err error) {
sysconn, err := s.ipv6.SyscallConn()
func (bind *StdNetBind) PeekLookAtSocketFd6() (fd int, err error) {
sysconn, err := bind.ipv6.SyscallConn()
if err != nil {
return -1, err
}

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
// Package conn implements WireGuard's network connections.
@ -15,17 +15,10 @@ import (
"strings"
)
const (
IdealBatchSize = 128 // maximum number of packets handled per read and write
)
// A ReceiveFunc receives at least one packet from the network and writes them
// into packets. On a successful read it returns the number of elements of
// sizes, packets, and endpoints that should be evaluated. Some elements of
// sizes may be zero, and callers should ignore them. Callers must pass a sizes
// and eps slice with a length greater than or equal to the length of packets.
// These lengths must not exceed the length of the associated Bind.BatchSize().
type ReceiveFunc func(packets [][]byte, sizes []int, eps []Endpoint) (n int, err error)
// A ReceiveFunc receives a single inbound packet from the network.
// It writes the data into b. n is the length of the packet.
// ep is the remote endpoint.
type ReceiveFunc func(b []byte) (n int, ep Endpoint, err error)
// A Bind listens on a port for both IPv6 and IPv4 UDP traffic.
//
@ -45,16 +38,11 @@ type Bind interface {
// This mark is passed to the kernel as the socket option SO_MARK.
SetMark(mark uint32) error
// Send writes one or more packets in bufs to address ep. The length of
// bufs must not exceed BatchSize().
Send(bufs [][]byte, ep Endpoint) error
// Send writes a packet b to address ep.
Send(b []byte, ep Endpoint) error
// ParseEndpoint creates a new endpoint from a string.
ParseEndpoint(s string) (Endpoint, error)
// BatchSize is the number of buffers expected to be passed to
// the ReceiveFuncs, and the maximum expected to be passed to SendBatch.
BatchSize() int
}
// BindSocketToInterface is implemented by Bind objects that support being

View file

@ -1,24 +0,0 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
import (
"testing"
)
func TestPrettyName(t *testing.T) {
var (
recvFunc ReceiveFunc = func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { return }
)
const want = "TestPrettyName"
t.Run("ReceiveFunc.PrettyName", func(t *testing.T) {
if got := recvFunc.PrettyName(); got != want {
t.Errorf("PrettyName() = %v, want %v", got, want)
}
})
}

View file

@ -1,43 +0,0 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
import (
"net"
"syscall"
)
// UDP socket read/write buffer size (7MB). The value of 7MB is chosen as it is
// the max supported by a default configuration of macOS. Some platforms will
// silently clamp the value to other maximums, such as linux clamping to
// net.core.{r,w}mem_max (see _linux.go for additional implementation that works
// around this limitation)
const socketBufferSize = 7 << 20
// controlFn is the callback function signature from net.ListenConfig.Control.
// It is used to apply platform specific configuration to the socket prior to
// bind.
type controlFn func(network, address string, c syscall.RawConn) error
// controlFns is a list of functions that are called from the listen config
// that can apply socket options.
var controlFns = []controlFn{}
// listenConfig returns a net.ListenConfig that applies the controlFns to the
// socket prior to bind. This is used to apply socket buffer sizing and packet
// information OOB configuration for sticky sockets.
func listenConfig() *net.ListenConfig {
return &net.ListenConfig{
Control: func(network, address string, c syscall.RawConn) error {
for _, fn := range controlFns {
if err := fn(network, address, c); err != nil {
return err
}
}
return nil
},
}
}

View file

@ -1,61 +0,0 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
import (
"fmt"
"runtime"
"syscall"
"golang.org/x/sys/unix"
)
func init() {
controlFns = append(controlFns,
// Attempt to set the socket buffer size beyond net.core.{r,w}mem_max by
// using SO_*BUFFORCE. This requires CAP_NET_ADMIN, and is allowed here to
// fail silently - the result of failure is lower performance on very fast
// links or high latency links.
func(network, address string, c syscall.RawConn) error {
return c.Control(func(fd uintptr) {
// Set up to *mem_max
_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUF, socketBufferSize)
_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUF, socketBufferSize)
// Set beyond *mem_max if CAP_NET_ADMIN
_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, socketBufferSize)
_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUFFORCE, socketBufferSize)
})
},
// Enable receiving of the packet information (IP_PKTINFO for IPv4,
// IPV6_PKTINFO for IPv6) that is used to implement sticky socket support.
func(network, address string, c syscall.RawConn) error {
var err error
switch network {
case "udp4":
if runtime.GOOS != "android" {
c.Control(func(fd uintptr) {
err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_PKTINFO, 1)
})
}
case "udp6":
c.Control(func(fd uintptr) {
if runtime.GOOS != "android" {
err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVPKTINFO, 1)
if err != nil {
return
}
}
err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_V6ONLY, 1)
})
default:
err = fmt.Errorf("unhandled network: %s: %w", network, unix.EINVAL)
}
return err
},
)
}

View file

@ -1,35 +0,0 @@
//go:build !windows && !linux && !wasm
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
import (
"syscall"
"golang.org/x/sys/unix"
)
func init() {
controlFns = append(controlFns,
func(network, address string, c syscall.RawConn) error {
return c.Control(func(fd uintptr) {
_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUF, socketBufferSize)
_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUF, socketBufferSize)
})
},
func(network, address string, c syscall.RawConn) error {
var err error
if network == "udp6" {
c.Control(func(fd uintptr) {
err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_V6ONLY, 1)
})
}
return err
},
)
}

View file

@ -1,23 +0,0 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
import (
"syscall"
"golang.org/x/sys/windows"
)
func init() {
controlFns = append(controlFns,
func(network, address string, c syscall.RawConn) error {
return c.Control(func(fd uintptr) {
_ = windows.SetsockoptInt(windows.Handle(fd), windows.SOL_SOCKET, windows.SO_RCVBUF, socketBufferSize)
_ = windows.SetsockoptInt(windows.Handle(fd), windows.SOL_SOCKET, windows.SO_SNDBUF, socketBufferSize)
})
},
)
}

View file

@ -1,8 +1,8 @@
//go:build !windows
//go:build !linux && !windows
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
*/
package conn

View file

@ -1,12 +0,0 @@
//go:build !linux
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
func errShouldDisableUDPGSO(err error) bool {
return false
}

View file

@ -1,28 +0,0 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
import (
"errors"
"os"
"golang.org/x/sys/unix"
)
func errShouldDisableUDPGSO(err error) bool {
var serr *os.SyscallError
if errors.As(err, &serr) {
// EIO is returned by udp_send_skb() if the device driver does not have
// tx checksumming enabled, which is a hard requirement of UDP_SEGMENT.
// See:
// https://git.kernel.org/pub/scm/docs/man-pages/man-pages.git/tree/man7/udp.7?id=806eabd74910447f21005160e90957bde4db0183#n228
// https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/tree/net/ipv4/udp.c?h=v6.2&id=c9c3395d5e3dcc6daee66c6908354d47bf98cb0c#n942
// If gso_size + udp + ip headers > fragment size EINVAL is returned.
// It occurs when the peer mtu + wg headers is greater than path mtu.
return serr.Err == unix.EIO || serr.Err == unix.EINVAL
}
return false
}

View file

@ -1,15 +0,0 @@
//go:build !linux
// +build !linux
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
import "net"
func supportsUDPOffload(conn *net.UDPConn) (txOffload, rxOffload bool) {
return
}

View file

@ -1,31 +0,0 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
import (
"net"
"golang.org/x/sys/unix"
)
func supportsUDPOffload(conn *net.UDPConn) (txOffload, rxOffload bool) {
rc, err := conn.SyscallConn()
if err != nil {
return
}
err = rc.Control(func(fd uintptr) {
_, errSyscall := unix.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_SEGMENT)
txOffload = errSyscall == nil
// getsockopt(IPPROTO_UDP, UDP_GRO) is not supported in android
// use setsockopt workaround
errSyscall = unix.SetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_GRO, 1)
rxOffload = errSyscall == nil
})
if err != nil {
return false, false
}
return txOffload, rxOffload
}

View file

@ -1,21 +0,0 @@
//go:build !linux
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
// getGSOSize parses control for UDP_GRO and if found returns its GSO size data.
func getGSOSize(control []byte) (int, error) {
return 0, nil
}
// setGSOSize sets a UDP_SEGMENT in control based on gsoSize.
func setGSOSize(control *[]byte, gsoSize uint16) {
}
// gsoControlSize returns the recommended buffer size for pooling sticky and UDP
// offloading control data.
const gsoControlSize = 0

View file

@ -1,65 +0,0 @@
//go:build linux
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
import (
"fmt"
"unsafe"
"golang.org/x/sys/unix"
)
const (
sizeOfGSOData = 2
)
// getGSOSize parses control for UDP_GRO and if found returns its GSO size data.
func getGSOSize(control []byte) (int, error) {
var (
hdr unix.Cmsghdr
data []byte
rem = control
err error
)
for len(rem) > unix.SizeofCmsghdr {
hdr, data, rem, err = unix.ParseOneSocketControlMessage(rem)
if err != nil {
return 0, fmt.Errorf("error parsing socket control message: %w", err)
}
if hdr.Level == unix.SOL_UDP && hdr.Type == unix.UDP_GRO && len(data) >= sizeOfGSOData {
var gso uint16
copy(unsafe.Slice((*byte)(unsafe.Pointer(&gso)), sizeOfGSOData), data[:sizeOfGSOData])
return int(gso), nil
}
}
return 0, nil
}
// setGSOSize sets a UDP_SEGMENT in control based on gsoSize. It leaves existing
// data in control untouched.
func setGSOSize(control *[]byte, gsoSize uint16) {
existingLen := len(*control)
avail := cap(*control) - existingLen
space := unix.CmsgSpace(sizeOfGSOData)
if avail < space {
return
}
*control = (*control)[:cap(*control)]
gsoControl := (*control)[existingLen:]
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&(gsoControl)[0]))
hdr.Level = unix.SOL_UDP
hdr.Type = unix.UDP_SEGMENT
hdr.SetLen(unix.CmsgLen(sizeOfGSOData))
copy((gsoControl)[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&gsoSize)), sizeOfGSOData))
*control = (*control)[:existingLen+space]
}
// gsoControlSize returns the recommended buffer size for pooling UDP
// offloading control data.
var gsoControlSize = unix.CmsgSpace(sizeOfGSOData)

View file

@ -2,11 +2,11 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package conn
func (s *StdNetBind) SetMark(mark uint32) error {
func (bind *StdNetBind) SetMark(mark uint32) error {
return nil
}

View file

@ -2,7 +2,7 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package conn
@ -26,13 +26,13 @@ func init() {
}
}
func (s *StdNetBind) SetMark(mark uint32) error {
func (bind *StdNetBind) SetMark(mark uint32) error {
var operr error
if fwmarkIoctl == 0 {
return nil
}
if s.ipv4 != nil {
fd, err := s.ipv4.SyscallConn()
if bind.ipv4 != nil {
fd, err := bind.ipv4.SyscallConn()
if err != nil {
return err
}
@ -46,8 +46,8 @@ func (s *StdNetBind) SetMark(mark uint32) error {
return err
}
}
if s.ipv6 != nil {
fd, err := s.ipv6.SyscallConn()
if bind.ipv6 != nil {
fd, err := bind.ipv6.SyscallConn()
if err != nil {
return err
}

View file

@ -1,42 +0,0 @@
//go:build !linux || android
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
import "net/netip"
func (e *StdNetEndpoint) SrcIP() netip.Addr {
return netip.Addr{}
}
func (e *StdNetEndpoint) SrcIfidx() int32 {
return 0
}
func (e *StdNetEndpoint) SrcToString() string {
return ""
}
// TODO: macOS, FreeBSD and other BSDs likely do support the sticky sockets
// {get,set}srcControl feature set, but use alternatively named flags and need
// ports and require testing.
// getSrcFromControl parses the control for PKTINFO and if found updates ep with
// the source information found.
func getSrcFromControl(control []byte, ep *StdNetEndpoint) {
}
// setSrcControl parses the control for PKTINFO and if found updates ep with
// the source information found.
func setSrcControl(control *[]byte, ep *StdNetEndpoint) {
}
// stickyControlSize returns the recommended buffer size for pooling sticky
// offloading control data.
const stickyControlSize = 0
const StdNetSupportsStickySockets = false

View file

@ -1,112 +0,0 @@
//go:build linux && !android
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
import (
"net/netip"
"unsafe"
"golang.org/x/sys/unix"
)
func (e *StdNetEndpoint) SrcIP() netip.Addr {
switch len(e.src) {
case unix.CmsgSpace(unix.SizeofInet4Pktinfo):
info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
return netip.AddrFrom4(info.Spec_dst)
case unix.CmsgSpace(unix.SizeofInet6Pktinfo):
info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
// TODO: set zone. in order to do so we need to check if the address is
// link local, and if it is perform a syscall to turn the ifindex into a
// zone string because netip uses string zones.
return netip.AddrFrom16(info.Addr)
}
return netip.Addr{}
}
func (e *StdNetEndpoint) SrcIfidx() int32 {
switch len(e.src) {
case unix.CmsgSpace(unix.SizeofInet4Pktinfo):
info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
return info.Ifindex
case unix.CmsgSpace(unix.SizeofInet6Pktinfo):
info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
return int32(info.Ifindex)
}
return 0
}
func (e *StdNetEndpoint) SrcToString() string {
return e.SrcIP().String()
}
// getSrcFromControl parses the control for PKTINFO and if found updates ep with
// the source information found.
func getSrcFromControl(control []byte, ep *StdNetEndpoint) {
ep.ClearSrc()
var (
hdr unix.Cmsghdr
data []byte
rem []byte = control
err error
)
for len(rem) > unix.SizeofCmsghdr {
hdr, data, rem, err = unix.ParseOneSocketControlMessage(rem)
if err != nil {
return
}
if hdr.Level == unix.IPPROTO_IP &&
hdr.Type == unix.IP_PKTINFO {
if ep.src == nil || cap(ep.src) < unix.CmsgSpace(unix.SizeofInet4Pktinfo) {
ep.src = make([]byte, 0, unix.CmsgSpace(unix.SizeofInet4Pktinfo))
}
ep.src = ep.src[:unix.CmsgSpace(unix.SizeofInet4Pktinfo)]
hdrBuf := unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), unix.SizeofCmsghdr)
copy(ep.src, hdrBuf)
copy(ep.src[unix.CmsgLen(0):], data)
return
}
if hdr.Level == unix.IPPROTO_IPV6 &&
hdr.Type == unix.IPV6_PKTINFO {
if ep.src == nil || cap(ep.src) < unix.CmsgSpace(unix.SizeofInet6Pktinfo) {
ep.src = make([]byte, 0, unix.CmsgSpace(unix.SizeofInet6Pktinfo))
}
ep.src = ep.src[:unix.CmsgSpace(unix.SizeofInet6Pktinfo)]
hdrBuf := unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), unix.SizeofCmsghdr)
copy(ep.src, hdrBuf)
copy(ep.src[unix.CmsgLen(0):], data)
return
}
}
}
// setSrcControl sets an IP{V6}_PKTINFO in control based on the source address
// and source ifindex found in ep. control's len will be set to 0 in the event
// that ep is a default value.
func setSrcControl(control *[]byte, ep *StdNetEndpoint) {
if cap(*control) < len(ep.src) {
return
}
*control = (*control)[:0]
*control = append(*control, ep.src...)
}
// stickyControlSize returns the recommended buffer size for pooling sticky
// offloading control data.
var stickyControlSize = unix.CmsgSpace(unix.SizeofInet6Pktinfo)
const StdNetSupportsStickySockets = true

View file

@ -1,266 +0,0 @@
//go:build linux && !android
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
import (
"context"
"net"
"net/netip"
"runtime"
"testing"
"unsafe"
"golang.org/x/sys/unix"
)
func setSrc(ep *StdNetEndpoint, addr netip.Addr, ifidx int32) {
var buf []byte
if addr.Is4() {
buf = make([]byte, unix.CmsgSpace(unix.SizeofInet4Pktinfo))
hdr := unix.Cmsghdr{
Level: unix.IPPROTO_IP,
Type: unix.IP_PKTINFO,
}
hdr.SetLen(unix.CmsgLen(unix.SizeofInet4Pktinfo))
copy(buf, unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), int(unsafe.Sizeof(hdr))))
info := unix.Inet4Pktinfo{
Ifindex: ifidx,
Spec_dst: addr.As4(),
}
copy(buf[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&info)), unix.SizeofInet4Pktinfo))
} else {
buf = make([]byte, unix.CmsgSpace(unix.SizeofInet6Pktinfo))
hdr := unix.Cmsghdr{
Level: unix.IPPROTO_IPV6,
Type: unix.IPV6_PKTINFO,
}
hdr.SetLen(unix.CmsgLen(unix.SizeofInet6Pktinfo))
copy(buf, unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), int(unsafe.Sizeof(hdr))))
info := unix.Inet6Pktinfo{
Ifindex: uint32(ifidx),
Addr: addr.As16(),
}
copy(buf[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&info)), unix.SizeofInet6Pktinfo))
}
ep.src = buf
}
func Test_setSrcControl(t *testing.T) {
t.Run("IPv4", func(t *testing.T) {
ep := &StdNetEndpoint{
AddrPort: netip.MustParseAddrPort("127.0.0.1:1234"),
}
setSrc(ep, netip.MustParseAddr("127.0.0.1"), 5)
control := make([]byte, stickyControlSize)
setSrcControl(&control, ep)
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
if hdr.Level != unix.IPPROTO_IP {
t.Errorf("unexpected level: %d", hdr.Level)
}
if hdr.Type != unix.IP_PKTINFO {
t.Errorf("unexpected type: %d", hdr.Type)
}
if uint(hdr.Len) != uint(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{})))) {
t.Errorf("unexpected length: %d", hdr.Len)
}
info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
if info.Spec_dst[0] != 127 || info.Spec_dst[1] != 0 || info.Spec_dst[2] != 0 || info.Spec_dst[3] != 1 {
t.Errorf("unexpected address: %v", info.Spec_dst)
}
if info.Ifindex != 5 {
t.Errorf("unexpected ifindex: %d", info.Ifindex)
}
})
t.Run("IPv6", func(t *testing.T) {
ep := &StdNetEndpoint{
AddrPort: netip.MustParseAddrPort("[::1]:1234"),
}
setSrc(ep, netip.MustParseAddr("::1"), 5)
control := make([]byte, stickyControlSize)
setSrcControl(&control, ep)
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
if hdr.Level != unix.IPPROTO_IPV6 {
t.Errorf("unexpected level: %d", hdr.Level)
}
if hdr.Type != unix.IPV6_PKTINFO {
t.Errorf("unexpected type: %d", hdr.Type)
}
if uint(hdr.Len) != uint(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet6Pktinfo{})))) {
t.Errorf("unexpected length: %d", hdr.Len)
}
info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
if info.Addr != ep.SrcIP().As16() {
t.Errorf("unexpected address: %v", info.Addr)
}
if info.Ifindex != 5 {
t.Errorf("unexpected ifindex: %d", info.Ifindex)
}
})
t.Run("ClearOnNoSrc", func(t *testing.T) {
control := make([]byte, stickyControlSize)
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
hdr.Level = 1
hdr.Type = 2
hdr.Len = 3
setSrcControl(&control, &StdNetEndpoint{})
if len(control) != 0 {
t.Errorf("unexpected control: %v", control)
}
})
}
func Test_getSrcFromControl(t *testing.T) {
t.Run("IPv4", func(t *testing.T) {
control := make([]byte, stickyControlSize)
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
hdr.Level = unix.IPPROTO_IP
hdr.Type = unix.IP_PKTINFO
hdr.SetLen(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{}))))
info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
info.Spec_dst = [4]byte{127, 0, 0, 1}
info.Ifindex = 5
ep := &StdNetEndpoint{}
getSrcFromControl(control, ep)
if ep.SrcIP() != netip.MustParseAddr("127.0.0.1") {
t.Errorf("unexpected address: %v", ep.SrcIP())
}
if ep.SrcIfidx() != 5 {
t.Errorf("unexpected ifindex: %d", ep.SrcIfidx())
}
})
t.Run("IPv6", func(t *testing.T) {
control := make([]byte, stickyControlSize)
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
hdr.Level = unix.IPPROTO_IPV6
hdr.Type = unix.IPV6_PKTINFO
hdr.SetLen(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet6Pktinfo{}))))
info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
info.Addr = [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}
info.Ifindex = 5
ep := &StdNetEndpoint{}
getSrcFromControl(control, ep)
if ep.SrcIP() != netip.MustParseAddr("::1") {
t.Errorf("unexpected address: %v", ep.SrcIP())
}
if ep.SrcIfidx() != 5 {
t.Errorf("unexpected ifindex: %d", ep.SrcIfidx())
}
})
t.Run("ClearOnEmpty", func(t *testing.T) {
var control []byte
ep := &StdNetEndpoint{}
setSrc(ep, netip.MustParseAddr("::1"), 5)
getSrcFromControl(control, ep)
if ep.SrcIP().IsValid() {
t.Errorf("unexpected address: %v", ep.SrcIP())
}
if ep.SrcIfidx() != 0 {
t.Errorf("unexpected ifindex: %d", ep.SrcIfidx())
}
})
t.Run("Multiple", func(t *testing.T) {
zeroControl := make([]byte, unix.CmsgSpace(0))
zeroHdr := (*unix.Cmsghdr)(unsafe.Pointer(&zeroControl[0]))
zeroHdr.SetLen(unix.CmsgLen(0))
control := make([]byte, unix.CmsgSpace(unix.SizeofInet4Pktinfo))
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
hdr.Level = unix.IPPROTO_IP
hdr.Type = unix.IP_PKTINFO
hdr.SetLen(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{}))))
info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
info.Spec_dst = [4]byte{127, 0, 0, 1}
info.Ifindex = 5
combined := make([]byte, 0)
combined = append(combined, zeroControl...)
combined = append(combined, control...)
ep := &StdNetEndpoint{}
getSrcFromControl(combined, ep)
if ep.SrcIP() != netip.MustParseAddr("127.0.0.1") {
t.Errorf("unexpected address: %v", ep.SrcIP())
}
if ep.SrcIfidx() != 5 {
t.Errorf("unexpected ifindex: %d", ep.SrcIfidx())
}
})
}
func Test_listenConfig(t *testing.T) {
t.Run("IPv4", func(t *testing.T) {
conn, err := listenConfig().ListenPacket(context.Background(), "udp4", ":0")
if err != nil {
t.Fatal(err)
}
defer conn.Close()
sc, err := conn.(*net.UDPConn).SyscallConn()
if err != nil {
t.Fatal(err)
}
if runtime.GOOS == "linux" {
var i int
sc.Control(func(fd uintptr) {
i, err = unix.GetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_PKTINFO)
})
if err != nil {
t.Fatal(err)
}
if i != 1 {
t.Error("IP_PKTINFO not set!")
}
} else {
t.Logf("listenConfig() does not set IPV6_RECVPKTINFO on %s", runtime.GOOS)
}
})
t.Run("IPv6", func(t *testing.T) {
conn, err := listenConfig().ListenPacket(context.Background(), "udp6", ":0")
if err != nil {
t.Fatal(err)
}
sc, err := conn.(*net.UDPConn).SyscallConn()
if err != nil {
t.Fatal(err)
}
if runtime.GOOS == "linux" {
var i int
sc.Control(func(fd uintptr) {
i, err = unix.GetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVPKTINFO)
})
if err != nil {
t.Fatal(err)
}
if i != 1 {
t.Error("IPV6_PKTINFO not set!")
}
} else {
t.Logf("listenConfig() does not set IPV6_RECVPKTINFO on %s", runtime.GOOS)
}
})
}

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2021 WireGuard LLC. All Rights Reserved.
*/
package winrio

65
device/alignment_test.go Normal file
View file

@ -0,0 +1,65 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"reflect"
"testing"
"unsafe"
)
func checkAlignment(t *testing.T, name string, offset uintptr) {
t.Helper()
if offset%8 != 0 {
t.Errorf("offset of %q within struct is %d bytes, which does not align to 64-bit word boundaries (missing %d bytes). Atomic operations will crash on 32-bit systems.", name, offset, 8-(offset%8))
}
}
// TestPeerAlignment checks that atomically-accessed fields are
// aligned to 64-bit boundaries, as required by the atomic package.
//
// Unfortunately, violating this rule on 32-bit platforms results in a
// hard segfault at runtime.
func TestPeerAlignment(t *testing.T) {
var p Peer
typ := reflect.TypeOf(&p).Elem()
t.Logf("Peer type size: %d, with fields:", typ.Size())
for i := 0; i < typ.NumField(); i++ {
field := typ.Field(i)
t.Logf("\t%30s\toffset=%3v\t(type size=%3d, align=%d)",
field.Name,
field.Offset,
field.Type.Size(),
field.Type.Align(),
)
}
checkAlignment(t, "Peer.stats", unsafe.Offsetof(p.stats))
checkAlignment(t, "Peer.isRunning", unsafe.Offsetof(p.isRunning))
}
// TestDeviceAlignment checks that atomically-accessed fields are
// aligned to 64-bit boundaries, as required by the atomic package.
//
// Unfortunately, violating this rule on 32-bit platforms results in a
// hard segfault at runtime.
func TestDeviceAlignment(t *testing.T) {
var d Device
typ := reflect.TypeOf(&d).Elem()
t.Logf("Device type size: %d, with fields:", typ.Size())
for i := 0; i < typ.NumField(); i++ {
field := typ.Field(i)
t.Logf("\t%30s\toffset=%3v\t(type size=%3d, align=%d)",
field.Name,
field.Offset,
field.Type.Size(),
field.Type.Align(),
)
}
checkAlignment(t, "Device.rate.underLoadUntil", unsafe.Offsetof(d.rate)+unsafe.Offsetof(d.rate.underLoadUntil))
}

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package device

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package device

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package device

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package device
@ -8,7 +8,7 @@ package device
import (
"errors"
"github.com/amnezia-vpn/amneziawg-go/conn"
"golang.zx2c4.com/wireguard/conn"
)
type DummyDatagram struct {
@ -26,21 +26,21 @@ func (b *DummyBind) SetMark(v uint32) error {
return nil
}
func (b *DummyBind) ReceiveIPv6(buf []byte) (int, conn.Endpoint, error) {
func (b *DummyBind) ReceiveIPv6(buff []byte) (int, conn.Endpoint, error) {
datagram, ok := <-b.in6
if !ok {
return 0, nil, errors.New("closed")
}
copy(buf, datagram.msg)
copy(buff, datagram.msg)
return len(datagram.msg), datagram.endpoint, nil
}
func (b *DummyBind) ReceiveIPv4(buf []byte) (int, conn.Endpoint, error) {
func (b *DummyBind) ReceiveIPv4(buff []byte) (int, conn.Endpoint, error) {
datagram, ok := <-b.in4
if !ok {
return 0, nil, errors.New("closed")
}
copy(buf, datagram.msg)
copy(buff, datagram.msg)
return len(datagram.msg), datagram.endpoint, nil
}
@ -51,6 +51,6 @@ func (b *DummyBind) Close() error {
return nil
}
func (b *DummyBind) Send(buf []byte, end conn.Endpoint) error {
func (b *DummyBind) Send(buff []byte, end conn.Endpoint) error {
return nil
}

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package device
@ -19,13 +19,13 @@ import (
// call wg.Done to remove the initial reference.
// When the refcount hits 0, the queue's channel is closed.
type outboundQueue struct {
c chan *QueueOutboundElementsContainer
c chan *QueueOutboundElement
wg sync.WaitGroup
}
func newOutboundQueue() *outboundQueue {
q := &outboundQueue{
c: make(chan *QueueOutboundElementsContainer, QueueOutboundSize),
c: make(chan *QueueOutboundElement, QueueOutboundSize),
}
q.wg.Add(1)
go func() {
@ -37,13 +37,13 @@ func newOutboundQueue() *outboundQueue {
// A inboundQueue is similar to an outboundQueue; see those docs.
type inboundQueue struct {
c chan *QueueInboundElementsContainer
c chan *QueueInboundElement
wg sync.WaitGroup
}
func newInboundQueue() *inboundQueue {
q := &inboundQueue{
c: make(chan *QueueInboundElementsContainer, QueueInboundSize),
c: make(chan *QueueInboundElement, QueueInboundSize),
}
q.wg.Add(1)
go func() {
@ -72,7 +72,7 @@ func newHandshakeQueue() *handshakeQueue {
}
type autodrainingInboundQueue struct {
c chan *QueueInboundElementsContainer
c chan *QueueInboundElement
}
// newAutodrainingInboundQueue returns a channel that will be drained when it gets GC'd.
@ -81,7 +81,7 @@ type autodrainingInboundQueue struct {
// some other means, such as sending a sentinel nil values.
func newAutodrainingInboundQueue(device *Device) *autodrainingInboundQueue {
q := &autodrainingInboundQueue{
c: make(chan *QueueInboundElementsContainer, QueueInboundSize),
c: make(chan *QueueInboundElement, QueueInboundSize),
}
runtime.SetFinalizer(q, device.flushInboundQueue)
return q
@ -90,13 +90,10 @@ func newAutodrainingInboundQueue(device *Device) *autodrainingInboundQueue {
func (device *Device) flushInboundQueue(q *autodrainingInboundQueue) {
for {
select {
case elemsContainer := <-q.c:
elemsContainer.Lock()
for _, elem := range elemsContainer.elems {
device.PutMessageBuffer(elem.buffer)
device.PutInboundElement(elem)
}
device.PutInboundElementsContainer(elemsContainer)
case elem := <-q.c:
elem.Lock()
device.PutMessageBuffer(elem.buffer)
device.PutInboundElement(elem)
default:
return
}
@ -104,7 +101,7 @@ func (device *Device) flushInboundQueue(q *autodrainingInboundQueue) {
}
type autodrainingOutboundQueue struct {
c chan *QueueOutboundElementsContainer
c chan *QueueOutboundElement
}
// newAutodrainingOutboundQueue returns a channel that will be drained when it gets GC'd.
@ -114,7 +111,7 @@ type autodrainingOutboundQueue struct {
// All sends to the channel must be best-effort, because there may be no receivers.
func newAutodrainingOutboundQueue(device *Device) *autodrainingOutboundQueue {
q := &autodrainingOutboundQueue{
c: make(chan *QueueOutboundElementsContainer, QueueOutboundSize),
c: make(chan *QueueOutboundElement, QueueOutboundSize),
}
runtime.SetFinalizer(q, device.flushOutboundQueue)
return q
@ -123,13 +120,10 @@ func newAutodrainingOutboundQueue(device *Device) *autodrainingOutboundQueue {
func (device *Device) flushOutboundQueue(q *autodrainingOutboundQueue) {
for {
select {
case elemsContainer := <-q.c:
elemsContainer.Lock()
for _, elem := range elemsContainer.elems {
device.PutMessageBuffer(elem.buffer)
device.PutOutboundElement(elem)
}
device.PutOutboundElementsContainer(elemsContainer)
case elem := <-q.c:
elem.Lock()
device.PutMessageBuffer(elem.buffer)
device.PutOutboundElement(elem)
default:
return
}

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package device

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package device

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package device

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package device
@ -11,12 +11,10 @@ import (
"sync/atomic"
"time"
"github.com/amnezia-vpn/amneziawg-go/conn"
"github.com/amnezia-vpn/amneziawg-go/ipc"
"github.com/amnezia-vpn/amneziawg-go/ratelimiter"
"github.com/amnezia-vpn/amneziawg-go/rwcancel"
"github.com/amnezia-vpn/amneziawg-go/tun"
"github.com/tevino/abool/v2"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/ratelimiter"
"golang.zx2c4.com/wireguard/rwcancel"
"golang.zx2c4.com/wireguard/tun"
)
type Device struct {
@ -32,7 +30,7 @@ type Device struct {
// will become the actual state; Up can fail.
// The device can also change state multiple times between time of check and time of use.
// Unsynchronized uses of state must therefore be advisory/best-effort only.
state atomic.Uint32 // actually a deviceState, but typed uint32 for convenience
state uint32 // actually a deviceState, but typed uint32 for convenience
// stopping blocks until all inputs to Device have been closed.
stopping sync.WaitGroup
// mu protects state changes.
@ -60,8 +58,9 @@ type Device struct {
keyMap map[NoisePublicKey]*Peer
}
// Keep this 8-byte aligned
rate struct {
underLoadUntil atomic.Int64
underLoadUntil int64
limiter ratelimiter.Ratelimiter
}
@ -70,11 +69,9 @@ type Device struct {
cookieChecker CookieChecker
pool struct {
inboundElementsContainer *WaitPool
outboundElementsContainer *WaitPool
messageBuffers *WaitPool
inboundElements *WaitPool
outboundElements *WaitPool
messageBuffers *WaitPool
inboundElements *WaitPool
outboundElements *WaitPool
}
queue struct {
@ -85,39 +82,22 @@ type Device struct {
tun struct {
device tun.Device
mtu atomic.Int32
mtu int32
}
ipcMutex sync.RWMutex
closed chan struct{}
log *Logger
isASecOn abool.AtomicBool
aSecMux sync.RWMutex
aSecCfg aSecCfgType
junkCreator junkCreator
}
type aSecCfgType struct {
isSet bool
junkPacketCount int
junkPacketMinSize int
junkPacketMaxSize int
initPacketJunkSize int
responsePacketJunkSize int
initPacketMagicHeader uint32
responsePacketMagicHeader uint32
underloadPacketMagicHeader uint32
transportPacketMagicHeader uint32
}
// deviceState represents the state of a Device.
// There are three states: down, up, closed.
// Transitions:
//
// down -----+
// ↑↓ ↓
// up -> closed
// down -----+
// ↑↓ ↓
// up -> closed
//
type deviceState uint32
//go:generate go run golang.org/x/tools/cmd/stringer -type deviceState -trimprefix=deviceState
@ -130,7 +110,7 @@ const (
// deviceState returns device.state.state as a deviceState
// See those docs for how to interpret this value.
func (device *Device) deviceState() deviceState {
return deviceState(device.state.state.Load())
return deviceState(atomic.LoadUint32(&device.state.state))
}
// isClosed reports whether the device is closed (or is closing).
@ -169,21 +149,20 @@ func (device *Device) changeState(want deviceState) (err error) {
case old:
return nil
case deviceStateUp:
device.state.state.Store(uint32(deviceStateUp))
atomic.StoreUint32(&device.state.state, uint32(deviceStateUp))
err = device.upLocked()
if err == nil {
break
}
fallthrough // up failed; bring the device all the way back down
case deviceStateDown:
device.state.state.Store(uint32(deviceStateDown))
atomic.StoreUint32(&device.state.state, uint32(deviceStateDown))
errDown := device.downLocked()
if err == nil {
err = errDown
}
}
device.log.Verbosef(
"Interface state was %s, requested %s, now %s", old, want, device.deviceState())
device.log.Verbosef("Interface state was %s, requested %s, now %s", old, want, device.deviceState())
return
}
@ -203,7 +182,7 @@ func (device *Device) upLocked() error {
device.peers.RLock()
for _, peer := range device.peers.keyMap {
peer.Start()
if peer.persistentKeepaliveInterval.Load() > 0 {
if atomic.LoadUint32(&peer.persistentKeepaliveInterval) > 0 {
peer.SendKeepalive()
}
}
@ -240,11 +219,11 @@ func (device *Device) IsUnderLoad() bool {
now := time.Now()
underLoad := len(device.queue.handshake.c) >= QueueHandshakeSize/8
if underLoad {
device.rate.underLoadUntil.Store(now.Add(UnderLoadAfterTime).UnixNano())
atomic.StoreInt64(&device.rate.underLoadUntil, now.Add(UnderLoadAfterTime).UnixNano())
return true
}
// check if recently under load
return device.rate.underLoadUntil.Load() > now.UnixNano()
return atomic.LoadInt64(&device.rate.underLoadUntil) > now.UnixNano()
}
func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
@ -288,7 +267,7 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
expiredPeers := make([]*Peer, 0, len(device.peers.keyMap))
for _, peer := range device.peers.keyMap {
handshake := &peer.handshake
handshake.precomputedStaticStatic, _ = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic)
handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic)
expiredPeers = append(expiredPeers, peer)
}
@ -304,7 +283,7 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device {
device := new(Device)
device.state.state.Store(uint32(deviceStateDown))
device.state.state = uint32(deviceStateDown)
device.closed = make(chan struct{})
device.log = logger
device.net.bind = bind
@ -314,11 +293,10 @@ func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device {
device.log.Errorf("Trouble determining MTU, assuming default: %v", err)
mtu = DefaultMTU
}
device.tun.mtu.Store(int32(mtu))
device.tun.mtu = int32(mtu)
device.peers.keyMap = make(map[NoisePublicKey]*Peer)
device.rate.limiter.Init()
device.indexTable.Init()
device.PopulatePools()
// create queues
@ -346,19 +324,6 @@ func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device {
return device
}
// BatchSize returns the BatchSize for the device as a whole which is the max of
// the bind batch size and the tun batch size. The batch size reported by device
// is the size used to construct memory pools, and is the allowed batch size for
// the lifetime of the device.
func (device *Device) BatchSize() int {
size := device.net.bind.BatchSize()
dSize := device.tun.device.BatchSize()
if size < dSize {
size = dSize
}
return size
}
func (device *Device) LookupPeer(pk NoisePublicKey) *Peer {
device.peers.RLock()
defer device.peers.RUnlock()
@ -391,12 +356,10 @@ func (device *Device) RemoveAllPeers() {
func (device *Device) Close() {
device.state.Lock()
defer device.state.Unlock()
device.ipcMutex.Lock()
defer device.ipcMutex.Unlock()
if device.isClosed() {
return
}
device.state.state.Store(uint32(deviceStateClosed))
atomic.StoreUint32(&device.state.state, uint32(deviceStateClosed))
device.log.Verbosef("Device closing")
device.tun.device.Close()
@ -416,8 +379,6 @@ func (device *Device) Close() {
device.rate.limiter.Close()
device.resetProtocol()
device.log.Verbosef("Device closed")
close(device.closed)
}
@ -484,7 +445,11 @@ func (device *Device) BindSetMark(mark uint32) error {
// clear cached source addresses
device.peers.RLock()
for _, peer := range device.peers.keyMap {
peer.markEndpointSrcForClearing()
peer.Lock()
defer peer.Unlock()
if peer.endpoint != nil {
peer.endpoint.ClearSrc()
}
}
device.peers.RUnlock()
@ -509,13 +474,11 @@ func (device *Device) BindUpdate() error {
var err error
var recvFns []conn.ReceiveFunc
netc := &device.net
recvFns, netc.port, err = netc.bind.Open(netc.port)
if err != nil {
netc.port = 0
return err
}
netc.netlinkCancel, err = device.startRouteListener(netc.bind)
if err != nil {
netc.bind.Close()
@ -534,7 +497,11 @@ func (device *Device) BindUpdate() error {
// clear cached source addresses
device.peers.RLock()
for _, peer := range device.peers.keyMap {
peer.markEndpointSrcForClearing()
peer.Lock()
defer peer.Unlock()
if peer.endpoint != nil {
peer.endpoint.ClearSrc()
}
}
device.peers.RUnlock()
@ -542,9 +509,8 @@ func (device *Device) BindUpdate() error {
device.net.stopping.Add(len(recvFns))
device.queue.decryption.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption
device.queue.handshake.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.handshake
batchSize := netc.bind.BatchSize()
for _, fn := range recvFns {
go device.RoutineReceiveIncoming(batchSize, fn)
go device.RoutineReceiveIncoming(fn)
}
device.log.Verbosef("UDP bind has been updated")
@ -557,251 +523,3 @@ func (device *Device) BindClose() error {
device.net.Unlock()
return err
}
func (device *Device) isAdvancedSecurityOn() bool {
return device.isASecOn.IsSet()
}
func (device *Device) resetProtocol() {
// restore default message type values
MessageInitiationType = 1
MessageResponseType = 2
MessageCookieReplyType = 3
MessageTransportType = 4
}
func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
if !tempASecCfg.isSet {
return err
}
isASecOn := false
device.aSecMux.Lock()
if tempASecCfg.junkPacketCount < 0 {
err = ipcErrorf(
ipc.IpcErrorInvalid,
"JunkPacketCount should be non negative",
)
}
device.aSecCfg.junkPacketCount = tempASecCfg.junkPacketCount
if tempASecCfg.junkPacketCount != 0 {
isASecOn = true
}
device.aSecCfg.junkPacketMinSize = tempASecCfg.junkPacketMinSize
if tempASecCfg.junkPacketMinSize != 0 {
isASecOn = true
}
if device.aSecCfg.junkPacketCount > 0 &&
tempASecCfg.junkPacketMaxSize == tempASecCfg.junkPacketMinSize {
tempASecCfg.junkPacketMaxSize++ // to make rand gen work
}
if tempASecCfg.junkPacketMaxSize >= MaxSegmentSize {
device.aSecCfg.junkPacketMinSize = 0
device.aSecCfg.junkPacketMaxSize = 1
if err != nil {
err = ipcErrorf(
ipc.IpcErrorInvalid,
"JunkPacketMaxSize: %d; should be smaller than maxSegmentSize: %d; %w",
tempASecCfg.junkPacketMaxSize,
MaxSegmentSize,
err,
)
} else {
err = ipcErrorf(
ipc.IpcErrorInvalid,
"JunkPacketMaxSize: %d; should be smaller than maxSegmentSize: %d",
tempASecCfg.junkPacketMaxSize,
MaxSegmentSize,
)
}
} else if tempASecCfg.junkPacketMaxSize < tempASecCfg.junkPacketMinSize {
if err != nil {
err = ipcErrorf(
ipc.IpcErrorInvalid,
"maxSize: %d; should be greater than minSize: %d; %w",
tempASecCfg.junkPacketMaxSize,
tempASecCfg.junkPacketMinSize,
err,
)
} else {
err = ipcErrorf(
ipc.IpcErrorInvalid,
"maxSize: %d; should be greater than minSize: %d",
tempASecCfg.junkPacketMaxSize,
tempASecCfg.junkPacketMinSize,
)
}
} else {
device.aSecCfg.junkPacketMaxSize = tempASecCfg.junkPacketMaxSize
}
if tempASecCfg.junkPacketMaxSize != 0 {
isASecOn = true
}
if MessageInitiationSize+tempASecCfg.initPacketJunkSize >= MaxSegmentSize {
if err != nil {
err = ipcErrorf(
ipc.IpcErrorInvalid,
`init header size(148) + junkSize:%d; should be smaller than maxSegmentSize: %d; %w`,
tempASecCfg.initPacketJunkSize,
MaxSegmentSize,
err,
)
} else {
err = ipcErrorf(
ipc.IpcErrorInvalid,
`init header size(148) + junkSize:%d; should be smaller than maxSegmentSize: %d`,
tempASecCfg.initPacketJunkSize,
MaxSegmentSize,
)
}
} else {
device.aSecCfg.initPacketJunkSize = tempASecCfg.initPacketJunkSize
}
if tempASecCfg.initPacketJunkSize != 0 {
isASecOn = true
}
if MessageResponseSize+tempASecCfg.responsePacketJunkSize >= MaxSegmentSize {
if err != nil {
err = ipcErrorf(
ipc.IpcErrorInvalid,
`response header size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d; %w`,
tempASecCfg.responsePacketJunkSize,
MaxSegmentSize,
err,
)
} else {
err = ipcErrorf(
ipc.IpcErrorInvalid,
`response header size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d`,
tempASecCfg.responsePacketJunkSize,
MaxSegmentSize,
)
}
} else {
device.aSecCfg.responsePacketJunkSize = tempASecCfg.responsePacketJunkSize
}
if tempASecCfg.responsePacketJunkSize != 0 {
isASecOn = true
}
if tempASecCfg.initPacketMagicHeader > 4 {
isASecOn = true
device.log.Verbosef("UAPI: Updating init_packet_magic_header")
device.aSecCfg.initPacketMagicHeader = tempASecCfg.initPacketMagicHeader
MessageInitiationType = device.aSecCfg.initPacketMagicHeader
} else {
device.log.Verbosef("UAPI: Using default init type")
MessageInitiationType = 1
}
if tempASecCfg.responsePacketMagicHeader > 4 {
isASecOn = true
device.log.Verbosef("UAPI: Updating response_packet_magic_header")
device.aSecCfg.responsePacketMagicHeader = tempASecCfg.responsePacketMagicHeader
MessageResponseType = device.aSecCfg.responsePacketMagicHeader
} else {
device.log.Verbosef("UAPI: Using default response type")
MessageResponseType = 2
}
if tempASecCfg.underloadPacketMagicHeader > 4 {
isASecOn = true
device.log.Verbosef("UAPI: Updating underload_packet_magic_header")
device.aSecCfg.underloadPacketMagicHeader = tempASecCfg.underloadPacketMagicHeader
MessageCookieReplyType = device.aSecCfg.underloadPacketMagicHeader
} else {
device.log.Verbosef("UAPI: Using default underload type")
MessageCookieReplyType = 3
}
if tempASecCfg.transportPacketMagicHeader > 4 {
isASecOn = true
device.log.Verbosef("UAPI: Updating transport_packet_magic_header")
device.aSecCfg.transportPacketMagicHeader = tempASecCfg.transportPacketMagicHeader
MessageTransportType = device.aSecCfg.transportPacketMagicHeader
} else {
device.log.Verbosef("UAPI: Using default transport type")
MessageTransportType = 4
}
isSameMap := map[uint32]bool{}
isSameMap[MessageInitiationType] = true
isSameMap[MessageResponseType] = true
isSameMap[MessageCookieReplyType] = true
isSameMap[MessageTransportType] = true
// size will be different if same values
if len(isSameMap) != 4 {
if err != nil {
err = ipcErrorf(
ipc.IpcErrorInvalid,
`magic headers should differ; got: init:%d; recv:%d; unde:%d; tran:%d; %w`,
MessageInitiationType,
MessageResponseType,
MessageCookieReplyType,
MessageTransportType,
err,
)
} else {
err = ipcErrorf(
ipc.IpcErrorInvalid,
`magic headers should differ; got: init:%d; recv:%d; unde:%d; tran:%d`,
MessageInitiationType,
MessageResponseType,
MessageCookieReplyType,
MessageTransportType,
)
}
}
newInitSize := MessageInitiationSize + device.aSecCfg.initPacketJunkSize
newResponseSize := MessageResponseSize + device.aSecCfg.responsePacketJunkSize
if newInitSize == newResponseSize {
if err != nil {
err = ipcErrorf(
ipc.IpcErrorInvalid,
`new init size:%d; and new response size:%d; should differ; %w`,
newInitSize,
newResponseSize,
err,
)
} else {
err = ipcErrorf(
ipc.IpcErrorInvalid,
`new init size:%d; and new response size:%d; should differ`,
newInitSize,
newResponseSize,
)
}
} else {
packetSizeToMsgType = map[int]uint32{
newInitSize: MessageInitiationType,
newResponseSize: MessageResponseType,
MessageCookieReplySize: MessageCookieReplyType,
MessageTransportSize: MessageTransportType,
}
msgTypeToJunkSize = map[uint32]int{
MessageInitiationType: device.aSecCfg.initPacketJunkSize,
MessageResponseType: device.aSecCfg.responsePacketJunkSize,
MessageCookieReplyType: 0,
MessageTransportType: 0,
}
}
device.isASecOn.SetTo(isASecOn)
device.junkCreator, err = NewJunkCreator(device)
device.aSecMux.Unlock()
return err
}

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package device
@ -12,7 +12,6 @@ import (
"io"
"math/rand"
"net/netip"
"os"
"runtime"
"runtime/pprof"
"sync"
@ -20,10 +19,9 @@ import (
"testing"
"time"
"github.com/amnezia-vpn/amneziawg-go/conn"
"github.com/amnezia-vpn/amneziawg-go/conn/bindtest"
"github.com/amnezia-vpn/amneziawg-go/tun"
"github.com/amnezia-vpn/amneziawg-go/tun/tuntest"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/conn/bindtest"
"golang.zx2c4.com/wireguard/tun/tuntest"
)
// uapiCfg returns a string that contains cfg formatted use with IpcSet.
@ -91,65 +89,6 @@ func genConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) {
return
}
func genASecurityConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) {
var key1, key2 NoisePrivateKey
_, err := rand.Read(key1[:])
if err != nil {
tb.Errorf("unable to generate private key random bytes: %v", err)
}
_, err = rand.Read(key2[:])
if err != nil {
tb.Errorf("unable to generate private key random bytes: %v", err)
}
pub1, pub2 := key1.publicKey(), key2.publicKey()
cfgs[0] = uapiCfg(
"private_key", hex.EncodeToString(key1[:]),
"listen_port", "0",
"replace_peers", "true",
"jc", "5",
"jmin", "500",
"jmax", "1000",
"s1", "30",
"s2", "40",
"h1", "123456",
"h2", "67543",
"h4", "32345",
"h3", "123123",
"public_key", hex.EncodeToString(pub2[:]),
"protocol_version", "1",
"replace_allowed_ips", "true",
"allowed_ip", "1.0.0.2/32",
)
endpointCfgs[0] = uapiCfg(
"public_key", hex.EncodeToString(pub2[:]),
"endpoint", "127.0.0.1:%d",
)
cfgs[1] = uapiCfg(
"private_key", hex.EncodeToString(key2[:]),
"listen_port", "0",
"replace_peers", "true",
"jc", "5",
"jmin", "500",
"jmax", "1000",
"s1", "30",
"s2", "40",
"h1", "123456",
"h2", "67543",
"h4", "32345",
"h3", "123123",
"public_key", hex.EncodeToString(pub1[:]),
"protocol_version", "1",
"replace_allowed_ips", "true",
"allowed_ip", "1.0.0.1/32",
)
endpointCfgs[1] = uapiCfg(
"public_key", hex.EncodeToString(pub1[:]),
"endpoint", "127.0.0.1:%d",
)
return
}
// A testPair is a pair of testPeers.
type testPair [2]testPeer
@ -174,11 +113,7 @@ func (d SendDirection) String() string {
return "pong"
}
func (pair *testPair) Send(
tb testing.TB,
ping SendDirection,
done chan struct{},
) {
func (pair *testPair) Send(tb testing.TB, ping SendDirection, done chan struct{}) {
tb.Helper()
p0, p1 := pair[0], pair[1]
if !ping {
@ -212,16 +147,8 @@ func (pair *testPair) Send(
}
// genTestPair creates a testPair.
func genTestPair(
tb testing.TB,
realSocket, withASecurity bool,
) (pair testPair) {
var cfg, endpointCfg [2]string
if withASecurity {
cfg, endpointCfg = genASecurityConfigs(tb)
} else {
cfg, endpointCfg = genConfigs(tb)
}
func genTestPair(tb testing.TB, realSocket bool) (pair testPair) {
cfg, endpointCfg := genConfigs(tb)
var binds [2]conn.Bind
if realSocket {
binds[0], binds[1] = conn.NewDefaultBind(), conn.NewDefaultBind()
@ -265,18 +192,7 @@ func genTestPair(
func TestTwoDevicePing(t *testing.T) {
goroutineLeakCheck(t)
pair := genTestPair(t, true, false)
t.Run("ping 1.0.0.1", func(t *testing.T) {
pair.Send(t, Ping, nil)
})
t.Run("ping 1.0.0.2", func(t *testing.T) {
pair.Send(t, Pong, nil)
})
}
func TestASecurityTwoDevicePing(t *testing.T) {
goroutineLeakCheck(t)
pair := genTestPair(t, true, true)
pair := genTestPair(t, true)
t.Run("ping 1.0.0.1", func(t *testing.T) {
pair.Send(t, Ping, nil)
})
@ -291,7 +207,7 @@ func TestUpDown(t *testing.T) {
const otrials = 10
for n := 0; n < otrials; n++ {
pair := genTestPair(t, false, false)
pair := genTestPair(t, false)
for i := range pair {
for k := range pair[i].dev.peers.keyMap {
pair[i].dev.IpcSet(fmt.Sprintf("public_key=%s\npersistent_keepalive_interval=1\n", hex.EncodeToString(k[:])))
@ -325,7 +241,7 @@ func TestUpDown(t *testing.T) {
// TestConcurrencySafety does other things concurrently with tunnel use.
// It is intended to be used with the race detector to catch data races.
func TestConcurrencySafety(t *testing.T) {
pair := genTestPair(t, true, false)
pair := genTestPair(t, true)
done := make(chan struct{})
const warmupIters = 10
@ -391,22 +307,11 @@ func TestConcurrencySafety(t *testing.T) {
}
})
// Perform bind updates and keepalive sends concurrently with tunnel use.
t.Run("bindUpdate and keepalive", func(t *testing.T) {
const iters = 10
for i := 0; i < iters; i++ {
for _, peer := range pair {
peer.dev.BindUpdate()
peer.dev.SendKeepalivesToPeersWithCurrentKeypair()
}
}
})
close(done)
}
func BenchmarkLatency(b *testing.B) {
pair := genTestPair(b, true, false)
pair := genTestPair(b, true)
// Establish a connection.
pair.Send(b, Ping, nil)
@ -420,7 +325,7 @@ func BenchmarkLatency(b *testing.B) {
}
func BenchmarkThroughput(b *testing.B) {
pair := genTestPair(b, true, false)
pair := genTestPair(b, true)
// Establish a connection.
pair.Send(b, Ping, nil)
@ -428,7 +333,7 @@ func BenchmarkThroughput(b *testing.B) {
// Measure how long it takes to receive b.N packets,
// starting when we receive the first packet.
var recv atomic.Uint64
var recv uint64
var elapsed time.Duration
var wg sync.WaitGroup
wg.Add(1)
@ -437,7 +342,7 @@ func BenchmarkThroughput(b *testing.B) {
var start time.Time
for {
<-pair[0].tun.Inbound
new := recv.Add(1)
new := atomic.AddUint64(&recv, 1)
if new == 1 {
start = time.Now()
}
@ -453,7 +358,7 @@ func BenchmarkThroughput(b *testing.B) {
ping := tuntest.Ping(pair[0].ip, pair[1].ip)
pingc := pair[1].tun.Outbound
var sent uint64
for recv.Load() != uint64(b.N) {
for atomic.LoadUint64(&recv) != uint64(b.N) {
sent++
pingc <- ping
}
@ -464,7 +369,7 @@ func BenchmarkThroughput(b *testing.B) {
}
func BenchmarkUAPIGet(b *testing.B) {
pair := genTestPair(b, true, false)
pair := genTestPair(b, true)
pair.Send(b, Ping, nil)
pair.Send(b, Pong, nil)
b.ReportAllocs()
@ -500,73 +405,3 @@ func goroutineLeakCheck(t *testing.T) {
t.Fatalf("expected %d goroutines, got %d, leak?", startGoroutines, endGoroutines)
})
}
type fakeBindSized struct {
size int
}
func (b *fakeBindSized) Open(
port uint16,
) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
return nil, 0, nil
}
func (b *fakeBindSized) Close() error { return nil }
func (b *fakeBindSized) SetMark(mark uint32) error { return nil }
func (b *fakeBindSized) Send(bufs [][]byte, ep conn.Endpoint) error { return nil }
func (b *fakeBindSized) ParseEndpoint(s string) (conn.Endpoint, error) { return nil, nil }
func (b *fakeBindSized) BatchSize() int { return b.size }
type fakeTUNDeviceSized struct {
size int
}
func (t *fakeTUNDeviceSized) File() *os.File { return nil }
func (t *fakeTUNDeviceSized) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
return 0, nil
}
func (t *fakeTUNDeviceSized) Write(bufs [][]byte, offset int) (int, error) { return 0, nil }
func (t *fakeTUNDeviceSized) MTU() (int, error) { return 0, nil }
func (t *fakeTUNDeviceSized) Name() (string, error) { return "", nil }
func (t *fakeTUNDeviceSized) Events() <-chan tun.Event { return nil }
func (t *fakeTUNDeviceSized) Close() error { return nil }
func (t *fakeTUNDeviceSized) BatchSize() int { return t.size }
func TestBatchSize(t *testing.T) {
d := Device{}
d.net.bind = &fakeBindSized{1}
d.tun.device = &fakeTUNDeviceSized{1}
if want, got := 1, d.BatchSize(); got != want {
t.Errorf("expected batch size %d, got %d", want, got)
}
d.net.bind = &fakeBindSized{1}
d.tun.device = &fakeTUNDeviceSized{128}
if want, got := 128, d.BatchSize(); got != want {
t.Errorf("expected batch size %d, got %d", want, got)
}
d.net.bind = &fakeBindSized{128}
d.tun.device = &fakeTUNDeviceSized{1}
if want, got := 128, d.BatchSize(); got != want {
t.Errorf("expected batch size %d, got %d", want, got)
}
d.net.bind = &fakeBindSized{128}
d.tun.device = &fakeTUNDeviceSized{128}
if want, got := 128, d.BatchSize(); got != want {
t.Errorf("expected batch size %d, got %d", want, got)
}
}

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package device

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package device

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package device

View file

@ -1,69 +0,0 @@
package device
import (
"bytes"
crand "crypto/rand"
"fmt"
v2 "math/rand/v2"
)
type junkCreator struct {
device *Device
cha8Rand *v2.ChaCha8
}
func NewJunkCreator(d *Device) (junkCreator, error) {
buf := make([]byte, 32)
_, err := crand.Read(buf)
if err != nil {
return junkCreator{}, err
}
return junkCreator{device: d, cha8Rand: v2.NewChaCha8([32]byte(buf))}, nil
}
// Should be called with aSecMux RLocked
func (jc *junkCreator) createJunkPackets() ([][]byte, error) {
if jc.device.aSecCfg.junkPacketCount == 0 {
return nil, nil
}
junks := make([][]byte, 0, jc.device.aSecCfg.junkPacketCount)
for i := 0; i < jc.device.aSecCfg.junkPacketCount; i++ {
packetSize := jc.randomPacketSize()
junk, err := jc.randomJunkWithSize(packetSize)
if err != nil {
return nil, fmt.Errorf("Failed to create junk packet: %v", err)
}
junks = append(junks, junk)
}
return junks, nil
}
// Should be called with aSecMux RLocked
func (jc *junkCreator) randomPacketSize() int {
return int(
jc.cha8Rand.Uint64()%uint64(
jc.device.aSecCfg.junkPacketMaxSize-jc.device.aSecCfg.junkPacketMinSize,
),
) + jc.device.aSecCfg.junkPacketMinSize
}
// Should be called with aSecMux RLocked
func (jc *junkCreator) appendJunk(writer *bytes.Buffer, size int) error {
headerJunk, err := jc.randomJunkWithSize(size)
if err != nil {
return fmt.Errorf("failed to create header junk: %v", err)
}
_, err = writer.Write(headerJunk)
if err != nil {
return fmt.Errorf("failed to write header junk: %v", err)
}
return nil
}
// Should be called with aSecMux RLocked
func (jc *junkCreator) randomJunkWithSize(size int) ([]byte, error) {
junk := make([]byte, size)
_, err := jc.cha8Rand.Read(junk)
return junk, err
}

View file

@ -1,124 +0,0 @@
package device
import (
"bytes"
"fmt"
"testing"
"github.com/amnezia-vpn/amneziawg-go/conn/bindtest"
"github.com/amnezia-vpn/amneziawg-go/tun/tuntest"
)
func setUpJunkCreator(t *testing.T) (junkCreator, error) {
cfg, _ := genASecurityConfigs(t)
tun := tuntest.NewChannelTUN()
binds := bindtest.NewChannelBinds()
level := LogLevelVerbose
dev := NewDevice(
tun.TUN(),
binds[0],
NewLogger(level, ""),
)
if err := dev.IpcSet(cfg[0]); err != nil {
t.Errorf("failed to configure device %v", err)
dev.Close()
return junkCreator{}, err
}
jc, err := NewJunkCreator(dev)
if err != nil {
t.Errorf("failed to create junk creator %v", err)
dev.Close()
return junkCreator{}, err
}
return jc, nil
}
func Test_junkCreator_createJunkPackets(t *testing.T) {
jc, err := setUpJunkCreator(t)
if err != nil {
return
}
t.Run("", func(t *testing.T) {
got, err := jc.createJunkPackets()
if err != nil {
t.Errorf(
"junkCreator.createJunkPackets() = %v; failed",
err,
)
return
}
seen := make(map[string]bool)
for _, junk := range got {
key := string(junk)
if seen[key] {
t.Errorf(
"junkCreator.createJunkPackets() = %v, duplicate key: %v",
got,
junk,
)
return
}
seen[key] = true
}
})
}
func Test_junkCreator_randomJunkWithSize(t *testing.T) {
t.Run("", func(t *testing.T) {
jc, err := setUpJunkCreator(t)
if err != nil {
return
}
r1, _ := jc.randomJunkWithSize(10)
r2, _ := jc.randomJunkWithSize(10)
fmt.Printf("%v\n%v\n", r1, r2)
if bytes.Equal(r1, r2) {
t.Errorf("same junks %v", err)
jc.device.Close()
return
}
})
}
func Test_junkCreator_randomPacketSize(t *testing.T) {
jc, err := setUpJunkCreator(t)
if err != nil {
return
}
for range [30]struct{}{} {
t.Run("", func(t *testing.T) {
if got := jc.randomPacketSize(); jc.device.aSecCfg.junkPacketMinSize > got ||
got > jc.device.aSecCfg.junkPacketMaxSize {
t.Errorf(
"junkCreator.randomPacketSize() = %v, not between range [%v,%v]",
got,
jc.device.aSecCfg.junkPacketMinSize,
jc.device.aSecCfg.junkPacketMaxSize,
)
}
})
}
}
func Test_junkCreator_appendJunk(t *testing.T) {
jc, err := setUpJunkCreator(t)
if err != nil {
return
}
t.Run("", func(t *testing.T) {
s := "apple"
buffer := bytes.NewBuffer([]byte(s))
err := jc.appendJunk(buffer, 30)
if err != nil &&
buffer.Len() != len(s)+30 {
t.Errorf("appendWithJunk() size don't match")
}
read := make([]byte, 50)
buffer.Read(read)
fmt.Println(string(read))
})
}

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package device

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package device
@ -10,8 +10,9 @@ import (
"sync"
"sync/atomic"
"time"
"unsafe"
"github.com/amnezia-vpn/amneziawg-go/replay"
"golang.zx2c4.com/wireguard/replay"
)
/* Due to limitations in Go and /x/crypto there is currently
@ -22,7 +23,7 @@ import (
*/
type Keypair struct {
sendNonce atomic.Uint64
sendNonce uint64 // accessed atomically
send cipher.AEAD
receive cipher.AEAD
replayFilter replay.Filter
@ -36,7 +37,15 @@ type Keypairs struct {
sync.RWMutex
current *Keypair
previous *Keypair
next atomic.Pointer[Keypair]
next *Keypair
}
func (kp *Keypairs) storeNext(next *Keypair) {
atomic.StorePointer((*unsafe.Pointer)((unsafe.Pointer)(&kp.next)), (unsafe.Pointer)(next))
}
func (kp *Keypairs) loadNext() *Keypair {
return (*Keypair)(atomic.LoadPointer((*unsafe.Pointer)((unsafe.Pointer)(&kp.next))))
}
func (kp *Keypairs) Current() *Keypair {

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package device

41
device/misc.go Normal file
View file

@ -0,0 +1,41 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"sync/atomic"
)
/* Atomic Boolean */
const (
AtomicFalse = int32(iota)
AtomicTrue
)
type AtomicBool struct {
int32
}
func (a *AtomicBool) Get() bool {
return atomic.LoadInt32(&a.int32) == AtomicTrue
}
func (a *AtomicBool) Swap(val bool) bool {
flag := AtomicFalse
if val {
flag = AtomicTrue
}
return atomic.SwapInt32(&a.int32, flag) == AtomicTrue
}
func (a *AtomicBool) Set(val bool) {
flag := AtomicFalse
if val {
flag = AtomicTrue
}
atomic.StoreInt32(&a.int32, flag)
}

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package device
@ -11,9 +11,9 @@ func (device *Device) DisableSomeRoamingForBrokenMobileSemantics() {
device.net.brokenRoaming = true
device.peers.RLock()
for _, peer := range device.peers.keyMap {
peer.endpoint.Lock()
peer.endpoint.disableRoaming = peer.endpoint.val != nil
peer.endpoint.Unlock()
peer.Lock()
peer.disableRoaming = peer.endpoint != nil
peer.Unlock()
}
device.peers.RUnlock()
}

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package device
@ -9,7 +9,6 @@ import (
"crypto/hmac"
"crypto/rand"
"crypto/subtle"
"errors"
"hash"
"golang.org/x/crypto/blake2s"
@ -95,14 +94,9 @@ func (sk *NoisePrivateKey) publicKey() (pk NoisePublicKey) {
return
}
var errInvalidPublicKey = errors.New("invalid public key")
func (sk *NoisePrivateKey) sharedSecret(pk NoisePublicKey) (ss [NoisePublicKeySize]byte, err error) {
func (sk *NoisePrivateKey) sharedSecret(pk NoisePublicKey) (ss [NoisePublicKeySize]byte) {
apk := (*[NoisePublicKeySize]byte)(&pk)
ask := (*[NoisePrivateKeySize]byte)(sk)
curve25519.ScalarMult(&ss, ask, apk)
if isZero(ss[:]) {
return ss, errInvalidPublicKey
}
return ss, nil
return ss
}

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package device
@ -15,7 +15,7 @@ import (
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/crypto/poly1305"
"github.com/amnezia-vpn/amneziawg-go/tai64n"
"golang.zx2c4.com/wireguard/tai64n"
)
type handshakeState int
@ -52,11 +52,11 @@ const (
WGLabelCookie = "cookie--"
)
var (
MessageInitiationType uint32 = 1
MessageResponseType uint32 = 2
MessageCookieReplyType uint32 = 3
MessageTransportType uint32 = 4
const (
MessageInitiationType = 1
MessageResponseType = 2
MessageCookieReplyType = 3
MessageTransportType = 4
)
const (
@ -75,10 +75,6 @@ const (
MessageTransportOffsetContent = 16
)
var packetSizeToMsgType map[int]uint32
var msgTypeToJunkSize map[uint32]int
/* Type is an 8-bit field, followed by 3 nul bytes,
* by marshalling the messages in little-endian byteorder
* we can treat these as a 32-bit unsigned int (for now)
@ -179,6 +175,8 @@ func init() {
}
func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) {
errZeroECDHResult := errors.New("ECDH returned all zeros")
device.staticIdentity.RLock()
defer device.staticIdentity.RUnlock()
@ -197,20 +195,18 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
handshake.mixHash(handshake.remoteStatic[:])
device.aSecMux.RLock()
msg := MessageInitiation{
Type: MessageInitiationType,
Ephemeral: handshake.localEphemeral.publicKey(),
}
device.aSecMux.RUnlock()
handshake.mixKey(msg.Ephemeral[:])
handshake.mixHash(msg.Ephemeral[:])
// encrypt static key
ss, err := handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
if err != nil {
return nil, err
ss := handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
if isZero(ss[:]) {
return nil, errZeroECDHResult
}
var key [chacha20poly1305.KeySize]byte
KDF2(
@ -225,7 +221,7 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
// encrypt timestamp
if isZero(handshake.precomputedStaticStatic[:]) {
return nil, errInvalidPublicKey
return nil, errZeroECDHResult
}
KDF2(
&handshake.chainKey,
@ -256,12 +252,9 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
chainKey [blake2s.Size]byte
)
device.aSecMux.RLock()
if msg.Type != MessageInitiationType {
device.aSecMux.RUnlock()
return nil
}
device.aSecMux.RUnlock()
device.staticIdentity.RLock()
defer device.staticIdentity.RUnlock()
@ -271,10 +264,11 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:])
// decrypt static key
var err error
var peerPK NoisePublicKey
var key [chacha20poly1305.KeySize]byte
ss, err := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
if err != nil {
ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
if isZero(ss[:]) {
return nil
}
KDF2(&chainKey, &key, chainKey[:], ss[:])
@ -288,7 +282,7 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
// lookup peer
peer := device.LookupPeer(peerPK)
if peer == nil || !peer.isRunning.Load() {
if peer == nil || !peer.isRunning.Get() {
return nil
}
@ -376,9 +370,7 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
}
var msg MessageResponse
device.aSecMux.RLock()
msg.Type = MessageResponseType
device.aSecMux.RUnlock()
msg.Sender = handshake.localIndex
msg.Receiver = handshake.remoteIndex
@ -392,16 +384,12 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
handshake.mixHash(msg.Ephemeral[:])
handshake.mixKey(msg.Ephemeral[:])
ss, err := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral)
if err != nil {
return nil, err
}
handshake.mixKey(ss[:])
ss, err = handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
if err != nil {
return nil, err
}
handshake.mixKey(ss[:])
func() {
ss := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral)
handshake.mixKey(ss[:])
ss = handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
handshake.mixKey(ss[:])
}()
// add preshared key
@ -418,9 +406,11 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
handshake.mixHash(tau[:])
aead, _ := chacha20poly1305.New(key[:])
aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:])
handshake.mixHash(msg.Empty[:])
func() {
aead, _ := chacha20poly1305.New(key[:])
aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:])
handshake.mixHash(msg.Empty[:])
}()
handshake.state = handshakeResponseCreated
@ -428,12 +418,9 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
}
func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
device.aSecMux.RLock()
if msg.Type != MessageResponseType {
device.aSecMux.RUnlock()
return nil
}
device.aSecMux.RUnlock()
// lookup handshake by receiver
@ -468,19 +455,17 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
mixHash(&hash, &handshake.hash, msg.Ephemeral[:])
mixKey(&chainKey, &handshake.chainKey, msg.Ephemeral[:])
ss, err := handshake.localEphemeral.sharedSecret(msg.Ephemeral)
if err != nil {
return false
}
mixKey(&chainKey, &chainKey, ss[:])
setZero(ss[:])
func() {
ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral)
mixKey(&chainKey, &chainKey, ss[:])
setZero(ss[:])
}()
ss, err = device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
if err != nil {
return false
}
mixKey(&chainKey, &chainKey, ss[:])
setZero(ss[:])
func() {
ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
mixKey(&chainKey, &chainKey, ss[:])
setZero(ss[:])
}()
// add preshared key (psk)
@ -498,7 +483,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
// authenticate transcript
aead, _ := chacha20poly1305.New(key[:])
_, err = aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:])
_, err := aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:])
if err != nil {
return false
}
@ -596,12 +581,12 @@ func (peer *Peer) BeginSymmetricSession() error {
defer keypairs.Unlock()
previous := keypairs.previous
next := keypairs.next.Load()
next := keypairs.loadNext()
current := keypairs.current
if isInitiator {
if next != nil {
keypairs.next.Store(nil)
keypairs.storeNext(nil)
keypairs.previous = next
device.DeleteKeypair(current)
} else {
@ -610,7 +595,7 @@ func (peer *Peer) BeginSymmetricSession() error {
device.DeleteKeypair(previous)
keypairs.current = keypair
} else {
keypairs.next.Store(keypair)
keypairs.storeNext(keypair)
device.DeleteKeypair(next)
keypairs.previous = nil
device.DeleteKeypair(previous)
@ -622,18 +607,18 @@ func (peer *Peer) BeginSymmetricSession() error {
func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool {
keypairs := &peer.keypairs
if keypairs.next.Load() != receivedKeypair {
if keypairs.loadNext() != receivedKeypair {
return false
}
keypairs.Lock()
defer keypairs.Unlock()
if keypairs.next.Load() != receivedKeypair {
if keypairs.loadNext() != receivedKeypair {
return false
}
old := keypairs.previous
keypairs.previous = keypairs.current
peer.device.DeleteKeypair(old)
keypairs.current = keypairs.next.Load()
keypairs.next.Store(nil)
keypairs.current = keypairs.loadNext()
keypairs.storeNext(nil)
return true
}

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package device

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package device
@ -10,8 +10,8 @@ import (
"encoding/binary"
"testing"
"github.com/amnezia-vpn/amneziawg-go/conn"
"github.com/amnezia-vpn/amneziawg-go/tun/tuntest"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/tun/tuntest"
)
func TestCurveWrappers(t *testing.T) {
@ -24,10 +24,10 @@ func TestCurveWrappers(t *testing.T) {
pk1 := sk1.publicKey()
pk2 := sk2.publicKey()
ss1, err1 := sk1.sharedSecret(pk2)
ss2, err2 := sk2.sharedSecret(pk1)
ss1 := sk1.sharedSecret(pk2)
ss2 := sk2.sharedSecret(pk1)
if ss1 != ss2 || err1 != nil || err2 != nil {
if ss1 != ss2 {
t.Fatal("Failed to compute shared secet")
}
}
@ -148,7 +148,7 @@ func TestNoiseHandshake(t *testing.T) {
t.Fatal("failed to derive keypair for peer 2", err)
}
key1 := peer1.keypairs.next.Load()
key1 := peer1.keypairs.loadNext()
key2 := peer2.keypairs.current
// encrypting / decryption test

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package device
@ -12,35 +12,40 @@ import (
"sync/atomic"
"time"
"github.com/amnezia-vpn/amneziawg-go/conn"
"golang.zx2c4.com/wireguard/conn"
)
type Peer struct {
isRunning atomic.Bool
keypairs Keypairs
handshake Handshake
device *Device
stopping sync.WaitGroup // routines pending stop
txBytes atomic.Uint64 // bytes send to peer (endpoint)
rxBytes atomic.Uint64 // bytes received from peer
lastHandshakeNano atomic.Int64 // nano seconds since epoch
isRunning AtomicBool
sync.RWMutex // Mostly protects endpoint, but is generally taken whenever we modify peer
keypairs Keypairs
handshake Handshake
device *Device
endpoint conn.Endpoint
stopping sync.WaitGroup // routines pending stop
endpoint struct {
sync.Mutex
val conn.Endpoint
clearSrcOnTx bool // signal to val.ClearSrc() prior to next packet transmission
disableRoaming bool
// These fields are accessed with atomic operations, which must be
// 64-bit aligned even on 32-bit platforms. Go guarantees that an
// allocated struct will be 64-bit aligned. So we place
// atomically-accessed fields up front, so that they can share in
// this alignment before smaller fields throw it off.
stats struct {
txBytes uint64 // bytes send to peer (endpoint)
rxBytes uint64 // bytes received from peer
lastHandshakeNano int64 // nano seconds since epoch
}
disableRoaming bool
timers struct {
retransmitHandshake *Timer
sendKeepalive *Timer
newHandshake *Timer
zeroKeyMaterial *Timer
persistentKeepalive *Timer
handshakeAttempts atomic.Uint32
needAnotherKeepalive atomic.Bool
sentLastMinuteHandshake atomic.Bool
handshakeAttempts uint32
needAnotherKeepalive AtomicBool
sentLastMinuteHandshake AtomicBool
}
state struct {
@ -48,14 +53,14 @@ type Peer struct {
}
queue struct {
staged chan *QueueOutboundElementsContainer // staged packets before a handshake is available
outbound *autodrainingOutboundQueue // sequential ordering of udp transmission
inbound *autodrainingInboundQueue // sequential ordering of tun writing
staged chan *QueueOutboundElement // staged packets before a handshake is available
outbound *autodrainingOutboundQueue // sequential ordering of udp transmission
inbound *autodrainingInboundQueue // sequential ordering of tun writing
}
cookieGenerator CookieGenerator
trieEntries list.List
persistentKeepaliveInterval atomic.Uint32
persistentKeepaliveInterval uint32 // accessed atomically
}
func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
@ -77,12 +82,14 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
// create peer
peer := new(Peer)
peer.Lock()
defer peer.Unlock()
peer.cookieGenerator.Init(pk)
peer.device = device
peer.queue.outbound = newAutodrainingOutboundQueue(device)
peer.queue.inbound = newAutodrainingInboundQueue(device)
peer.queue.staged = make(chan *QueueOutboundElementsContainer, QueueStagedSize)
peer.queue.staged = make(chan *QueueOutboundElement, QueueStagedSize)
// map public key
_, ok := device.peers.keyMap[pk]
@ -93,16 +100,12 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
// pre-compute DH
handshake := &peer.handshake
handshake.mutex.Lock()
handshake.precomputedStaticStatic, _ = device.staticIdentity.privateKey.sharedSecret(pk)
handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(pk)
handshake.remoteStatic = pk
handshake.mutex.Unlock()
// reset endpoint
peer.endpoint.Lock()
peer.endpoint.val = nil
peer.endpoint.disableRoaming = false
peer.endpoint.clearSrcOnTx = false
peer.endpoint.Unlock()
peer.endpoint = nil
// init timers
peer.timersInit()
@ -113,7 +116,7 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
return peer, nil
}
func (peer *Peer) SendBuffers(buffers [][]byte) error {
func (peer *Peer) SendBuffer(buffer []byte) error {
peer.device.net.RLock()
defer peer.device.net.RUnlock()
@ -121,25 +124,16 @@ func (peer *Peer) SendBuffers(buffers [][]byte) error {
return nil
}
peer.endpoint.Lock()
endpoint := peer.endpoint.val
if endpoint == nil {
peer.endpoint.Unlock()
peer.RLock()
defer peer.RUnlock()
if peer.endpoint == nil {
return errors.New("no known endpoint for peer")
}
if peer.endpoint.clearSrcOnTx {
endpoint.ClearSrc()
peer.endpoint.clearSrcOnTx = false
}
peer.endpoint.Unlock()
err := peer.device.net.bind.Send(buffers, endpoint)
err := peer.device.net.bind.Send(buffer, peer.endpoint)
if err == nil {
var totalLen uint64
for _, b := range buffers {
totalLen += uint64(len(b))
}
peer.txBytes.Add(totalLen)
atomic.AddUint64(&peer.stats.txBytes, uint64(len(buffer)))
}
return err
}
@ -180,7 +174,7 @@ func (peer *Peer) Start() {
peer.state.Lock()
defer peer.state.Unlock()
if peer.isRunning.Load() {
if peer.isRunning.Get() {
return
}
@ -201,14 +195,10 @@ func (peer *Peer) Start() {
device.flushInboundQueue(peer.queue.inbound)
device.flushOutboundQueue(peer.queue.outbound)
go peer.RoutineSequentialSender()
go peer.RoutineSequentialReceiver()
// Use the device batch size, not the bind batch size, as the device size is
// the size of the batch pools.
batchSize := peer.device.BatchSize()
go peer.RoutineSequentialSender(batchSize)
go peer.RoutineSequentialReceiver(batchSize)
peer.isRunning.Store(true)
peer.isRunning.Set(true)
}
func (peer *Peer) ZeroAndFlushAll() {
@ -220,10 +210,10 @@ func (peer *Peer) ZeroAndFlushAll() {
keypairs.Lock()
device.DeleteKeypair(keypairs.previous)
device.DeleteKeypair(keypairs.current)
device.DeleteKeypair(keypairs.next.Load())
device.DeleteKeypair(keypairs.loadNext())
keypairs.previous = nil
keypairs.current = nil
keypairs.next.Store(nil)
keypairs.storeNext(nil)
keypairs.Unlock()
// clear handshake state
@ -248,10 +238,11 @@ func (peer *Peer) ExpireCurrentKeypairs() {
keypairs := &peer.keypairs
keypairs.Lock()
if keypairs.current != nil {
keypairs.current.sendNonce.Store(RejectAfterMessages)
atomic.StoreUint64(&keypairs.current.sendNonce, RejectAfterMessages)
}
if next := keypairs.next.Load(); next != nil {
next.sendNonce.Store(RejectAfterMessages)
if keypairs.next != nil {
next := keypairs.loadNext()
atomic.StoreUint64(&next.sendNonce, RejectAfterMessages)
}
keypairs.Unlock()
}
@ -277,20 +268,10 @@ func (peer *Peer) Stop() {
}
func (peer *Peer) SetEndpointFromPacket(endpoint conn.Endpoint) {
peer.endpoint.Lock()
defer peer.endpoint.Unlock()
if peer.endpoint.disableRoaming {
if peer.disableRoaming {
return
}
peer.endpoint.clearSrcOnTx = false
peer.endpoint.val = endpoint
}
func (peer *Peer) markEndpointSrcForClearing() {
peer.endpoint.Lock()
defer peer.endpoint.Unlock()
if peer.endpoint.val == nil {
return
}
peer.endpoint.clearSrcOnTx = true
peer.Lock()
peer.endpoint = endpoint
peer.Unlock()
}

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package device
@ -14,7 +14,7 @@ type WaitPool struct {
pool sync.Pool
cond sync.Cond
lock sync.Mutex
count atomic.Uint32
count uint32
max uint32
}
@ -27,10 +27,10 @@ func NewWaitPool(max uint32, new func() any) *WaitPool {
func (p *WaitPool) Get() any {
if p.max != 0 {
p.lock.Lock()
for p.count.Load() >= p.max {
for atomic.LoadUint32(&p.count) >= p.max {
p.cond.Wait()
}
p.count.Add(1)
atomic.AddUint32(&p.count, 1)
p.lock.Unlock()
}
return p.pool.Get()
@ -41,19 +41,11 @@ func (p *WaitPool) Put(x any) {
if p.max == 0 {
return
}
p.count.Add(^uint32(0))
atomic.AddUint32(&p.count, ^uint32(0))
p.cond.Signal()
}
func (device *Device) PopulatePools() {
device.pool.inboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any {
s := make([]*QueueInboundElement, 0, device.BatchSize())
return &QueueInboundElementsContainer{elems: s}
})
device.pool.outboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any {
s := make([]*QueueOutboundElement, 0, device.BatchSize())
return &QueueOutboundElementsContainer{elems: s}
})
device.pool.messageBuffers = NewWaitPool(PreallocatedBuffersPerPool, func() any {
return new([MaxMessageSize]byte)
})
@ -65,34 +57,6 @@ func (device *Device) PopulatePools() {
})
}
func (device *Device) GetInboundElementsContainer() *QueueInboundElementsContainer {
c := device.pool.inboundElementsContainer.Get().(*QueueInboundElementsContainer)
c.Mutex = sync.Mutex{}
return c
}
func (device *Device) PutInboundElementsContainer(c *QueueInboundElementsContainer) {
for i := range c.elems {
c.elems[i] = nil
}
c.elems = c.elems[:0]
device.pool.inboundElementsContainer.Put(c)
}
func (device *Device) GetOutboundElementsContainer() *QueueOutboundElementsContainer {
c := device.pool.outboundElementsContainer.Get().(*QueueOutboundElementsContainer)
c.Mutex = sync.Mutex{}
return c
}
func (device *Device) PutOutboundElementsContainer(c *QueueOutboundElementsContainer) {
for i := range c.elems {
c.elems[i] = nil
}
c.elems = c.elems[:0]
device.pool.outboundElementsContainer.Put(c)
}
func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte {
return device.pool.messageBuffers.Get().(*[MaxMessageSize]byte)
}

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
*/
package device
@ -17,31 +17,29 @@ import (
func TestWaitPool(t *testing.T) {
t.Skip("Currently disabled")
var wg sync.WaitGroup
var trials atomic.Int32
startTrials := int32(100000)
trials := int32(100000)
if raceEnabled {
// This test can be very slow with -race.
startTrials /= 10
trials /= 10
}
trials.Store(startTrials)
workers := runtime.NumCPU() + 2
if workers-4 <= 0 {
t.Skip("Not enough cores")
}
p := NewWaitPool(uint32(workers-4), func() any { return make([]byte, 16) })
wg.Add(workers)
var max atomic.Uint32
max := uint32(0)
updateMax := func() {
count := p.count.Load()
count := atomic.LoadUint32(&p.count)
if count > p.max {
t.Errorf("count (%d) > max (%d)", count, p.max)
}
for {
old := max.Load()
old := atomic.LoadUint32(&max)
if count <= old {
break
}
if max.CompareAndSwap(old, count) {
if atomic.CompareAndSwapUint32(&max, old, count) {
break
}
}
@ -49,7 +47,7 @@ func TestWaitPool(t *testing.T) {
for i := 0; i < workers; i++ {
go func() {
defer wg.Done()
for trials.Add(-1) > 0 {
for atomic.AddInt32(&trials, -1) > 0 {
updateMax()
x := p.Get()
updateMax()
@ -61,15 +59,14 @@ func TestWaitPool(t *testing.T) {
}()
}
wg.Wait()
if max.Load() != p.max {
if max != p.max {
t.Errorf("Actual maximum count (%d) != ideal maximum count (%d)", max, p.max)
}
}
func BenchmarkWaitPool(b *testing.B) {
var wg sync.WaitGroup
var trials atomic.Int32
trials.Store(int32(b.N))
trials := int32(b.N)
workers := runtime.NumCPU() + 2
if workers-4 <= 0 {
b.Skip("Not enough cores")
@ -80,55 +77,7 @@ func BenchmarkWaitPool(b *testing.B) {
for i := 0; i < workers; i++ {
go func() {
defer wg.Done()
for trials.Add(-1) > 0 {
x := p.Get()
time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond)
p.Put(x)
}
}()
}
wg.Wait()
}
func BenchmarkWaitPoolEmpty(b *testing.B) {
var wg sync.WaitGroup
var trials atomic.Int32
trials.Store(int32(b.N))
workers := runtime.NumCPU() + 2
if workers-4 <= 0 {
b.Skip("Not enough cores")
}
p := NewWaitPool(0, func() any { return make([]byte, 16) })
wg.Add(workers)
b.ResetTimer()
for i := 0; i < workers; i++ {
go func() {
defer wg.Done()
for trials.Add(-1) > 0 {
x := p.Get()
time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond)
p.Put(x)
}
}()
}
wg.Wait()
}
func BenchmarkSyncPool(b *testing.B) {
var wg sync.WaitGroup
var trials atomic.Int32
trials.Store(int32(b.N))
workers := runtime.NumCPU() + 2
if workers-4 <= 0 {
b.Skip("Not enough cores")
}
p := sync.Pool{New: func() any { return make([]byte, 16) }}
wg.Add(workers)
b.ResetTimer()
for i := 0; i < workers; i++ {
go func() {
defer wg.Done()
for trials.Add(-1) > 0 {
for atomic.AddInt32(&trials, -1) > 0 {
x := p.Get()
time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond)
p.Put(x)

View file

@ -1,19 +1,17 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package device
import "github.com/amnezia-vpn/amneziawg-go/conn"
/* Reduce memory consumption for Android */
const (
QueueStagedSize = conn.IdealBatchSize
QueueStagedSize = 128
QueueOutboundSize = 1024
QueueInboundSize = 1024
QueueHandshakeSize = 1024
MaxSegmentSize = (1 << 16) - 1 // largest possible UDP datagram
MaxSegmentSize = 2200
PreallocatedBuffersPerPool = 4096
)

View file

@ -2,15 +2,13 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package device
import "github.com/amnezia-vpn/amneziawg-go/conn"
const (
QueueStagedSize = conn.IdealBatchSize
QueueStagedSize = 128
QueueOutboundSize = 1024
QueueInboundSize = 1024
QueueHandshakeSize = 1024

View file

@ -2,7 +2,7 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package device

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package device

View file

@ -2,7 +2,7 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package device

View file

@ -2,7 +2,7 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package device

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package device
@ -11,12 +11,13 @@ import (
"errors"
"net"
"sync"
"sync/atomic"
"time"
"github.com/amnezia-vpn/amneziawg-go/conn"
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
"golang.zx2c4.com/wireguard/conn"
)
type QueueHandshakeElement struct {
@ -27,6 +28,7 @@ type QueueHandshakeElement struct {
}
type QueueInboundElement struct {
sync.Mutex
buffer *[MaxMessageSize]byte
packet []byte
counter uint64
@ -34,11 +36,6 @@ type QueueInboundElement struct {
endpoint conn.Endpoint
}
type QueueInboundElementsContainer struct {
sync.Mutex
elems []*QueueInboundElement
}
// clearPointers clears elem fields that contain pointers.
// This makes the garbage collector's life easier and
// avoids accidentally keeping other objects around unnecessarily.
@ -55,12 +52,12 @@ func (elem *QueueInboundElement) clearPointers() {
* NOTE: Not thread safe, but called by sequential receiver!
*/
func (peer *Peer) keepKeyFreshReceiving() {
if peer.timers.sentLastMinuteHandshake.Load() {
if peer.timers.sentLastMinuteHandshake.Get() {
return
}
keypair := peer.keypairs.Current()
if keypair != nil && keypair.isInitiator && time.Since(keypair.created) > (RejectAfterTime-KeepaliveTimeout-RekeyTimeout) {
peer.timers.sentLastMinuteHandshake.Store(true)
peer.timers.sentLastMinuteHandshake.Set(true)
peer.SendHandshakeInitiation(false)
}
}
@ -70,10 +67,7 @@ func (peer *Peer) keepKeyFreshReceiving() {
* Every time the bind is updated a new routine is started for
* IPv4 and IPv6 (separately)
*/
func (device *Device) RoutineReceiveIncoming(
maxBatchSize int,
recv conn.ReceiveFunc,
) {
func (device *Device) RoutineReceiveIncoming(recv conn.ReceiveFunc) {
recvName := recv.PrettyName()
defer func() {
device.log.Verbosef("Routine: receive incoming %s - stopped", recvName)
@ -86,33 +80,20 @@ func (device *Device) RoutineReceiveIncoming(
// receive datagrams until conn is closed
buffer := device.GetMessageBuffer()
var (
bufsArrs = make([]*[MaxMessageSize]byte, maxBatchSize)
bufs = make([][]byte, maxBatchSize)
err error
sizes = make([]int, maxBatchSize)
count int
endpoints = make([]conn.Endpoint, maxBatchSize)
size int
endpoint conn.Endpoint
deathSpiral int
elemsByPeer = make(map[*Peer]*QueueInboundElementsContainer, maxBatchSize)
)
for i := range bufsArrs {
bufsArrs[i] = device.GetMessageBuffer()
bufs[i] = bufsArrs[i][:]
}
defer func() {
for i := 0; i < maxBatchSize; i++ {
if bufsArrs[i] != nil {
device.PutMessageBuffer(bufsArrs[i])
}
}
}()
for {
count, err = recv(bufs, sizes, endpoints)
size, endpoint, err = recv(buffer[:])
if err != nil {
device.PutMessageBuffer(buffer)
if errors.Is(err, net.ErrClosed) {
return
}
@ -123,142 +104,101 @@ func (device *Device) RoutineReceiveIncoming(
if deathSpiral < 10 {
deathSpiral++
time.Sleep(time.Second / 3)
buffer = device.GetMessageBuffer()
continue
}
return
}
deathSpiral = 0
device.aSecMux.RLock()
// handle each packet in the batch
for i, size := range sizes[:count] {
if size < MinMessageSize {
if size < MinMessageSize {
continue
}
// check size of packet
packet := buffer[:size]
msgType := binary.LittleEndian.Uint32(packet[:4])
var okay bool
switch msgType {
// check if transport
case MessageTransportType:
// check size
if len(packet) < MessageTransportSize {
continue
}
// check size of packet
// lookup key pair
packet := bufsArrs[i][:size]
var msgType uint32
if device.isAdvancedSecurityOn() {
if assumedMsgType, ok := packetSizeToMsgType[size]; ok {
junkSize := msgTypeToJunkSize[assumedMsgType]
// transport size can align with other header types;
// making sure we have the right msgType
msgType = binary.LittleEndian.Uint32(packet[junkSize : junkSize+4])
if msgType == assumedMsgType {
packet = packet[junkSize:]
} else {
device.log.Verbosef("Transport packet lined up with another msg type")
msgType = binary.LittleEndian.Uint32(packet[:4])
}
} else {
msgType = binary.LittleEndian.Uint32(packet[:4])
if msgType != MessageTransportType {
device.log.Verbosef("ASec: Received message with unknown type")
continue
}
}
receiver := binary.LittleEndian.Uint32(
packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter],
)
value := device.indexTable.Lookup(receiver)
keypair := value.keypair
if keypair == nil {
continue
}
// check keypair expiry
if keypair.created.Add(RejectAfterTime).Before(time.Now()) {
continue
}
// create work element
peer := value.peer
elem := device.GetInboundElement()
elem.packet = packet
elem.buffer = buffer
elem.keypair = keypair
elem.endpoint = endpoint
elem.counter = 0
elem.Mutex = sync.Mutex{}
elem.Lock()
// add to decryption queues
if peer.isRunning.Get() {
peer.queue.inbound.c <- elem
device.queue.decryption.c <- elem
buffer = device.GetMessageBuffer()
} else {
msgType = binary.LittleEndian.Uint32(packet[:4])
device.PutInboundElement(elem)
}
switch msgType {
continue
// check if transport
// otherwise it is a fixed size & handshake related packet
case MessageTransportType:
case MessageInitiationType:
okay = len(packet) == MessageInitiationSize
// check size
case MessageResponseType:
okay = len(packet) == MessageResponseSize
if len(packet) < MessageTransportSize {
continue
}
case MessageCookieReplyType:
okay = len(packet) == MessageCookieReplySize
// lookup key pair
receiver := binary.LittleEndian.Uint32(
packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter],
)
value := device.indexTable.Lookup(receiver)
keypair := value.keypair
if keypair == nil {
continue
}
// check keypair expiry
if keypair.created.Add(RejectAfterTime).Before(time.Now()) {
continue
}
// create work element
peer := value.peer
elem := device.GetInboundElement()
elem.packet = packet
elem.buffer = bufsArrs[i]
elem.keypair = keypair
elem.endpoint = endpoints[i]
elem.counter = 0
elemsForPeer, ok := elemsByPeer[peer]
if !ok {
elemsForPeer = device.GetInboundElementsContainer()
elemsForPeer.Lock()
elemsByPeer[peer] = elemsForPeer
}
elemsForPeer.elems = append(elemsForPeer.elems, elem)
bufsArrs[i] = device.GetMessageBuffer()
bufs[i] = bufsArrs[i][:]
continue
// otherwise it is a fixed size & handshake related packet
case MessageInitiationType:
if len(packet) != MessageInitiationSize {
continue
}
case MessageResponseType:
if len(packet) != MessageResponseSize {
continue
}
case MessageCookieReplyType:
if len(packet) != MessageCookieReplySize {
continue
}
default:
device.log.Verbosef("Received message with unknown type")
continue
}
default:
device.log.Verbosef("Received message with unknown type")
}
if okay {
select {
case device.queue.handshake.c <- QueueHandshakeElement{
msgType: msgType,
buffer: bufsArrs[i],
buffer: buffer,
packet: packet,
endpoint: endpoints[i],
endpoint: endpoint,
}:
bufsArrs[i] = device.GetMessageBuffer()
bufs[i] = bufsArrs[i][:]
buffer = device.GetMessageBuffer()
default:
}
}
device.aSecMux.RUnlock()
for peer, elemsContainer := range elemsByPeer {
if peer.isRunning.Load() {
peer.queue.inbound.c <- elemsContainer
device.queue.decryption.c <- elemsContainer
} else {
for _, elem := range elemsContainer.elems {
device.PutMessageBuffer(elem.buffer)
device.PutInboundElement(elem)
}
device.PutInboundElementsContainer(elemsContainer)
}
delete(elemsByPeer, peer)
}
}
}
@ -268,28 +208,26 @@ func (device *Device) RoutineDecryption(id int) {
defer device.log.Verbosef("Routine: decryption worker %d - stopped", id)
device.log.Verbosef("Routine: decryption worker %d - started", id)
for elemsContainer := range device.queue.decryption.c {
for _, elem := range elemsContainer.elems {
// split message into fields
counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent]
content := elem.packet[MessageTransportOffsetContent:]
for elem := range device.queue.decryption.c {
// split message into fields
counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent]
content := elem.packet[MessageTransportOffsetContent:]
// decrypt and release to consumer
var err error
elem.counter = binary.LittleEndian.Uint64(counter)
// copy counter to nonce
binary.LittleEndian.PutUint64(nonce[0x4:0xc], elem.counter)
elem.packet, err = elem.keypair.receive.Open(
content[:0],
nonce[:],
content,
nil,
)
if err != nil {
elem.packet = nil
}
// decrypt and release to consumer
var err error
elem.counter = binary.LittleEndian.Uint64(counter)
// copy counter to nonce
binary.LittleEndian.PutUint64(nonce[0x4:0xc], elem.counter)
elem.packet, err = elem.keypair.receive.Open(
content[:0],
nonce[:],
content,
nil,
)
if err != nil {
elem.packet = nil
}
elemsContainer.Unlock()
elem.Unlock()
}
}
@ -304,8 +242,6 @@ func (device *Device) RoutineHandshake(id int) {
for elem := range device.queue.handshake.c {
device.aSecMux.RLock()
// handle cookie fields and ratelimiting
switch elem.msgType {
@ -332,15 +268,10 @@ func (device *Device) RoutineHandshake(id int) {
// consume reply
if peer := entry.peer; peer.isRunning.Load() {
device.log.Verbosef(
"Receiving cookie response from %s",
elem.endpoint.DstToString(),
)
if peer := entry.peer; peer.isRunning.Get() {
device.log.Verbosef("Receiving cookie response from %s", elem.endpoint.DstToString())
if !peer.cookieGenerator.ConsumeReply(&reply) {
device.log.Verbosef(
"Could not decrypt invalid cookie response",
)
device.log.Verbosef("Could not decrypt invalid cookie response")
}
}
@ -382,7 +313,9 @@ func (device *Device) RoutineHandshake(id int) {
switch elem.msgType {
case MessageInitiationType:
// unmarshal
var msg MessageInitiation
reader := bytes.NewReader(elem.packet)
err := binary.Read(reader, binary.LittleEndian, &msg)
@ -392,6 +325,7 @@ func (device *Device) RoutineHandshake(id int) {
}
// consume initiation
peer := device.ConsumeMessageInitiation(&msg)
if peer == nil {
device.log.Verbosef("Received invalid initiation message from %s", elem.endpoint.DstToString())
@ -407,7 +341,7 @@ func (device *Device) RoutineHandshake(id int) {
peer.SetEndpointFromPacket(elem.endpoint)
device.log.Verbosef("%v - Received handshake initiation", peer)
peer.rxBytes.Add(uint64(len(elem.packet)))
atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)))
peer.SendHandshakeResponse()
@ -435,7 +369,7 @@ func (device *Device) RoutineHandshake(id int) {
peer.SetEndpointFromPacket(elem.endpoint)
device.log.Verbosef("%v - Received handshake response", peer)
peer.rxBytes.Add(uint64(len(elem.packet)))
atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)))
// update timers
@ -456,12 +390,11 @@ func (device *Device) RoutineHandshake(id int) {
peer.SendKeepalive()
}
skip:
device.aSecMux.RUnlock()
device.PutMessageBuffer(elem.buffer)
}
}
func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) {
func (peer *Peer) RoutineSequentialReceiver() {
device := peer.device
defer func() {
device.log.Verbosef("%v - Routine: sequential receiver - stopped", peer)
@ -469,109 +402,89 @@ func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) {
}()
device.log.Verbosef("%v - Routine: sequential receiver - started", peer)
bufs := make([][]byte, 0, maxBatchSize)
for elemsContainer := range peer.queue.inbound.c {
if elemsContainer == nil {
for elem := range peer.queue.inbound.c {
if elem == nil {
return
}
elemsContainer.Lock()
validTailPacket := -1
dataPacketReceived := false
rxBytesLen := uint64(0)
for i, elem := range elemsContainer.elems {
if elem.packet == nil {
// decryption failed
continue
}
if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) {
continue
}
validTailPacket = i
if peer.ReceivedWithKeypair(elem.keypair) {
peer.SetEndpointFromPacket(elem.endpoint)
peer.timersHandshakeComplete()
peer.SendStagedPackets()
}
rxBytesLen += uint64(len(elem.packet) + MinMessageSize)
if len(elem.packet) == 0 {
device.log.Verbosef("%v - Receiving keepalive packet", peer)
continue
}
dataPacketReceived = true
switch elem.packet[0] >> 4 {
case 4:
if len(elem.packet) < ipv4.HeaderLen {
continue
}
field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2]
length := binary.BigEndian.Uint16(field)
if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen {
continue
}
elem.packet = elem.packet[:length]
src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
if device.allowedips.Lookup(src) != peer {
device.log.Verbosef("IPv4 packet with disallowed source address from %v", peer)
continue
}
case 6:
if len(elem.packet) < ipv6.HeaderLen {
continue
}
field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
length := binary.BigEndian.Uint16(field)
length += ipv6.HeaderLen
if int(length) > len(elem.packet) {
continue
}
elem.packet = elem.packet[:length]
src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
if device.allowedips.Lookup(src) != peer {
device.log.Verbosef("IPv6 packet with disallowed source address from %v", peer)
continue
}
default:
device.log.Verbosef(
"Packet with invalid IP version from %v",
peer,
)
continue
}
bufs = append(
bufs,
elem.buffer[:MessageTransportOffsetContent+len(elem.packet)],
)
var err error
elem.Lock()
if elem.packet == nil {
// decryption failed
goto skip
}
peer.rxBytes.Add(rxBytesLen)
if validTailPacket >= 0 {
peer.SetEndpointFromPacket(elemsContainer.elems[validTailPacket].endpoint)
peer.keepKeyFreshReceiving()
peer.timersAnyAuthenticatedPacketTraversal()
peer.timersAnyAuthenticatedPacketReceived()
if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) {
goto skip
}
if dataPacketReceived {
peer.timersDataReceived()
peer.SetEndpointFromPacket(elem.endpoint)
if peer.ReceivedWithKeypair(elem.keypair) {
peer.timersHandshakeComplete()
peer.SendStagedPackets()
}
if len(bufs) > 0 {
_, err := device.tun.device.Write(bufs, MessageTransportOffsetContent)
if err != nil && !device.isClosed() {
device.log.Errorf("Failed to write packets to TUN device: %v", err)
peer.keepKeyFreshReceiving()
peer.timersAnyAuthenticatedPacketTraversal()
peer.timersAnyAuthenticatedPacketReceived()
atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)+MinMessageSize))
if len(elem.packet) == 0 {
device.log.Verbosef("%v - Receiving keepalive packet", peer)
goto skip
}
peer.timersDataReceived()
switch elem.packet[0] >> 4 {
case ipv4.Version:
if len(elem.packet) < ipv4.HeaderLen {
goto skip
}
field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2]
length := binary.BigEndian.Uint16(field)
if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen {
goto skip
}
elem.packet = elem.packet[:length]
src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
if device.allowedips.Lookup(src) != peer {
device.log.Verbosef("IPv4 packet with disallowed source address from %v", peer)
goto skip
}
case ipv6.Version:
if len(elem.packet) < ipv6.HeaderLen {
goto skip
}
field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
length := binary.BigEndian.Uint16(field)
length += ipv6.HeaderLen
if int(length) > len(elem.packet) {
goto skip
}
elem.packet = elem.packet[:length]
src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
if device.allowedips.Lookup(src) != peer {
device.log.Verbosef("IPv6 packet with disallowed source address from %v", peer)
goto skip
}
default:
device.log.Verbosef("Packet with invalid IP version from %v", peer)
goto skip
}
_, err = device.tun.device.Write(elem.buffer[:MessageTransportOffsetContent+len(elem.packet)], MessageTransportOffsetContent)
if err != nil && !device.isClosed() {
device.log.Errorf("Failed to write packet to TUN device: %v", err)
}
if len(peer.queue.inbound.c) == 0 {
err = device.tun.device.Flush()
if err != nil {
peer.device.log.Errorf("Unable to flush packets: %v", err)
}
}
for _, elem := range elemsContainer.elems {
device.PutMessageBuffer(elem.buffer)
device.PutInboundElement(elem)
}
bufs = bufs[:0]
device.PutInboundElementsContainer(elemsContainer)
skip:
device.PutMessageBuffer(elem.buffer)
device.PutInboundElement(elem)
}
}

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package device
@ -12,10 +12,9 @@ import (
"net"
"os"
"sync"
"sync/atomic"
"time"
"github.com/amnezia-vpn/amneziawg-go/conn"
"github.com/amnezia-vpn/amneziawg-go/tun"
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
@ -46,6 +45,7 @@ import (
*/
type QueueOutboundElement struct {
sync.Mutex
buffer *[MaxMessageSize]byte // slice holding the packet data
packet []byte // slice of "buffer" (always!)
nonce uint64 // nonce for encryption
@ -53,14 +53,10 @@ type QueueOutboundElement struct {
peer *Peer // related peer
}
type QueueOutboundElementsContainer struct {
sync.Mutex
elems []*QueueOutboundElement
}
func (device *Device) NewOutboundElement() *QueueOutboundElement {
elem := device.GetOutboundElement()
elem.buffer = device.GetMessageBuffer()
elem.Mutex = sync.Mutex{}
elem.nonce = 0
// keypair and peer were cleared (if necessary) by clearPointers.
return elem
@ -80,17 +76,14 @@ func (elem *QueueOutboundElement) clearPointers() {
/* Queues a keepalive if no packets are queued for peer
*/
func (peer *Peer) SendKeepalive() {
if len(peer.queue.staged) == 0 && peer.isRunning.Load() {
if len(peer.queue.staged) == 0 && peer.isRunning.Get() {
elem := peer.device.NewOutboundElement()
elemsContainer := peer.device.GetOutboundElementsContainer()
elemsContainer.elems = append(elemsContainer.elems, elem)
select {
case peer.queue.staged <- elemsContainer:
case peer.queue.staged <- elem:
peer.device.log.Verbosef("%v - Sending keepalive packet", peer)
default:
peer.device.PutMessageBuffer(elem.buffer)
peer.device.PutOutboundElement(elem)
peer.device.PutOutboundElementsContainer(elemsContainer)
}
}
peer.SendStagedPackets()
@ -98,7 +91,7 @@ func (peer *Peer) SendKeepalive() {
func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
if !isRetry {
peer.timers.handshakeAttempts.Store(0)
atomic.StoreUint32(&peer.timers.handshakeAttempts, 0)
}
peer.handshake.mutex.RLock()
@ -123,56 +116,17 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
peer.device.log.Errorf("%v - Failed to create initiation message: %v", peer, err)
return err
}
var sendBuffer [][]byte
// so only packet processed for cookie generation
var junkedHeader []byte
if peer.device.isAdvancedSecurityOn() {
peer.device.aSecMux.RLock()
junks, err := peer.device.junkCreator.createJunkPackets()
peer.device.aSecMux.RUnlock()
if err != nil {
peer.device.log.Errorf("%v - %v", peer, err)
return err
}
if len(junks) > 0 {
err = peer.SendBuffers(junks)
if err != nil {
peer.device.log.Errorf("%v - Failed to send junk packets: %v", peer, err)
return err
}
}
peer.device.aSecMux.RLock()
if peer.device.aSecCfg.initPacketJunkSize != 0 {
buf := make([]byte, 0, peer.device.aSecCfg.initPacketJunkSize)
writer := bytes.NewBuffer(buf[:0])
err = peer.device.junkCreator.appendJunk(writer, peer.device.aSecCfg.initPacketJunkSize)
if err != nil {
peer.device.log.Errorf("%v - %v", peer, err)
peer.device.aSecMux.RUnlock()
return err
}
junkedHeader = writer.Bytes()
}
peer.device.aSecMux.RUnlock()
}
var buf [MessageInitiationSize]byte
writer := bytes.NewBuffer(buf[:0])
var buff [MessageInitiationSize]byte
writer := bytes.NewBuffer(buff[:0])
binary.Write(writer, binary.LittleEndian, msg)
packet := writer.Bytes()
peer.cookieGenerator.AddMacs(packet)
junkedHeader = append(junkedHeader, packet...)
peer.timersAnyAuthenticatedPacketTraversal()
peer.timersAnyAuthenticatedPacketSent()
sendBuffer = append(sendBuffer, junkedHeader)
err = peer.SendBuffers(sendBuffer)
err = peer.SendBuffer(packet)
if err != nil {
peer.device.log.Errorf("%v - Failed to send handshake initiation: %v", peer, err)
}
@ -193,29 +147,12 @@ func (peer *Peer) SendHandshakeResponse() error {
peer.device.log.Errorf("%v - Failed to create response message: %v", peer, err)
return err
}
var junkedHeader []byte
if peer.device.isAdvancedSecurityOn() {
peer.device.aSecMux.RLock()
if peer.device.aSecCfg.responsePacketJunkSize != 0 {
buf := make([]byte, 0, peer.device.aSecCfg.responsePacketJunkSize)
writer := bytes.NewBuffer(buf[:0])
err = peer.device.junkCreator.appendJunk(writer, peer.device.aSecCfg.responsePacketJunkSize)
if err != nil {
peer.device.aSecMux.RUnlock()
peer.device.log.Errorf("%v - %v", peer, err)
return err
}
junkedHeader = writer.Bytes()
}
peer.device.aSecMux.RUnlock()
}
var buf [MessageResponseSize]byte
writer := bytes.NewBuffer(buf[:0])
var buff [MessageResponseSize]byte
writer := bytes.NewBuffer(buff[:0])
binary.Write(writer, binary.LittleEndian, response)
packet := writer.Bytes()
peer.cookieGenerator.AddMacs(packet)
junkedHeader = append(junkedHeader, packet...)
err = peer.BeginSymmetricSession()
if err != nil {
@ -227,35 +164,27 @@ func (peer *Peer) SendHandshakeResponse() error {
peer.timersAnyAuthenticatedPacketTraversal()
peer.timersAnyAuthenticatedPacketSent()
// TODO: allocation could be avoided
err = peer.SendBuffers([][]byte{junkedHeader})
err = peer.SendBuffer(packet)
if err != nil {
peer.device.log.Errorf("%v - Failed to send handshake response: %v", peer, err)
}
return err
}
func (device *Device) SendHandshakeCookie(
initiatingElem *QueueHandshakeElement,
) error {
func (device *Device) SendHandshakeCookie(initiatingElem *QueueHandshakeElement) error {
device.log.Verbosef("Sending cookie response for denied handshake message for %v", initiatingElem.endpoint.DstToString())
sender := binary.LittleEndian.Uint32(initiatingElem.packet[4:8])
reply, err := device.cookieChecker.CreateReply(
initiatingElem.packet,
sender,
initiatingElem.endpoint.DstToBytes(),
)
reply, err := device.cookieChecker.CreateReply(initiatingElem.packet, sender, initiatingElem.endpoint.DstToBytes())
if err != nil {
device.log.Errorf("Failed to create cookie reply: %v", err)
return err
}
var buf [MessageCookieReplySize]byte
writer := bytes.NewBuffer(buf[:0])
var buff [MessageCookieReplySize]byte
writer := bytes.NewBuffer(buff[:0])
binary.Write(writer, binary.LittleEndian, reply)
// TODO: allocation could be avoided
device.net.bind.Send([][]byte{writer.Bytes()}, initiatingElem.endpoint)
device.net.bind.Send(writer.Bytes(), initiatingElem.endpoint)
return nil
}
@ -264,12 +193,17 @@ func (peer *Peer) keepKeyFreshSending() {
if keypair == nil {
return
}
nonce := keypair.sendNonce.Load()
nonce := atomic.LoadUint64(&keypair.sendNonce)
if nonce > RekeyAfterMessages || (keypair.isInitiator && time.Since(keypair.created) > RekeyAfterTime) {
peer.SendHandshakeInitiation(false)
}
}
/* Reads packets from the TUN and inserts
* into staged queue for peer
*
* Obs. Single instance per TUN device
*/
func (device *Device) RoutineReadFromTUN() {
defer func() {
device.log.Verbosef("Routine: TUN reader - stopped")
@ -279,123 +213,81 @@ func (device *Device) RoutineReadFromTUN() {
device.log.Verbosef("Routine: TUN reader - started")
var (
batchSize = device.BatchSize()
readErr error
elems = make([]*QueueOutboundElement, batchSize)
bufs = make([][]byte, batchSize)
elemsByPeer = make(map[*Peer]*QueueOutboundElementsContainer, batchSize)
count = 0
sizes = make([]int, batchSize)
offset = MessageTransportHeaderSize
)
for i := range elems {
elems[i] = device.NewOutboundElement()
bufs[i] = elems[i].buffer[:]
}
defer func() {
for _, elem := range elems {
if elem != nil {
device.PutMessageBuffer(elem.buffer)
device.PutOutboundElement(elem)
}
}
}()
var elem *QueueOutboundElement
for {
// read packets
count, readErr = device.tun.device.Read(bufs, sizes, offset)
for i := 0; i < count; i++ {
if sizes[i] < 1 {
continue
}
elem := elems[i]
elem.packet = bufs[i][offset : offset+sizes[i]]
// lookup peer
var peer *Peer
switch elem.packet[0] >> 4 {
case 4:
if len(elem.packet) < ipv4.HeaderLen {
continue
}
dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
peer = device.allowedips.Lookup(dst)
case 6:
if len(elem.packet) < ipv6.HeaderLen {
continue
}
dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
peer = device.allowedips.Lookup(dst)
default:
device.log.Verbosef("Received packet with unknown IP version")
}
if peer == nil {
continue
}
elemsForPeer, ok := elemsByPeer[peer]
if !ok {
elemsForPeer = device.GetOutboundElementsContainer()
elemsByPeer[peer] = elemsForPeer
}
elemsForPeer.elems = append(elemsForPeer.elems, elem)
elems[i] = device.NewOutboundElement()
bufs[i] = elems[i].buffer[:]
if elem != nil {
device.PutMessageBuffer(elem.buffer)
device.PutOutboundElement(elem)
}
elem = device.NewOutboundElement()
for peer, elemsForPeer := range elemsByPeer {
if peer.isRunning.Load() {
peer.StagePackets(elemsForPeer)
peer.SendStagedPackets()
} else {
for _, elem := range elemsForPeer.elems {
device.PutMessageBuffer(elem.buffer)
device.PutOutboundElement(elem)
}
device.PutOutboundElementsContainer(elemsForPeer)
}
delete(elemsByPeer, peer)
}
// read packet
if readErr != nil {
if errors.Is(readErr, tun.ErrTooManySegments) {
// TODO: record stat for this
// This will happen if MSS is surprisingly small (< 576)
// coincident with reasonably high throughput.
device.log.Verbosef("Dropped some packets from multi-segment read: %v", readErr)
continue
}
offset := MessageTransportHeaderSize
size, err := device.tun.device.Read(elem.buffer[:], offset)
if err != nil {
if !device.isClosed() {
if !errors.Is(readErr, os.ErrClosed) {
device.log.Errorf("Failed to read packet from TUN device: %v", readErr)
if !errors.Is(err, os.ErrClosed) {
device.log.Errorf("Failed to read packet from TUN device: %v", err)
}
go device.Close()
}
device.PutMessageBuffer(elem.buffer)
device.PutOutboundElement(elem)
return
}
if size == 0 || size > MaxContentSize {
continue
}
elem.packet = elem.buffer[offset : offset+size]
// lookup peer
var peer *Peer
switch elem.packet[0] >> 4 {
case ipv4.Version:
if len(elem.packet) < ipv4.HeaderLen {
continue
}
dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
peer = device.allowedips.Lookup(dst)
case ipv6.Version:
if len(elem.packet) < ipv6.HeaderLen {
continue
}
dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
peer = device.allowedips.Lookup(dst)
default:
device.log.Verbosef("Received packet with unknown IP version")
}
if peer == nil {
continue
}
if peer.isRunning.Get() {
peer.StagePacket(elem)
elem = nil
peer.SendStagedPackets()
}
}
}
func (peer *Peer) StagePackets(elems *QueueOutboundElementsContainer) {
func (peer *Peer) StagePacket(elem *QueueOutboundElement) {
for {
select {
case peer.queue.staged <- elems:
case peer.queue.staged <- elem:
return
default:
}
select {
case tooOld := <-peer.queue.staged:
for _, elem := range tooOld.elems {
peer.device.PutMessageBuffer(elem.buffer)
peer.device.PutOutboundElement(elem)
}
peer.device.PutOutboundElementsContainer(tooOld)
peer.device.PutMessageBuffer(tooOld.buffer)
peer.device.PutOutboundElement(tooOld)
default:
}
}
@ -408,59 +300,32 @@ top:
}
keypair := peer.keypairs.Current()
if keypair == nil || keypair.sendNonce.Load() >= RejectAfterMessages || time.Since(keypair.created) >= RejectAfterTime {
if keypair == nil || atomic.LoadUint64(&keypair.sendNonce) >= RejectAfterMessages || time.Since(keypair.created) >= RejectAfterTime {
peer.SendHandshakeInitiation(false)
return
}
for {
var elemsContainerOOO *QueueOutboundElementsContainer
select {
case elemsContainer := <-peer.queue.staged:
i := 0
for _, elem := range elemsContainer.elems {
elem.peer = peer
elem.nonce = keypair.sendNonce.Add(1) - 1
if elem.nonce >= RejectAfterMessages {
keypair.sendNonce.Store(RejectAfterMessages)
if elemsContainerOOO == nil {
elemsContainerOOO = peer.device.GetOutboundElementsContainer()
}
elemsContainerOOO.elems = append(elemsContainerOOO.elems, elem)
continue
} else {
elemsContainer.elems[i] = elem
i++
}
elem.keypair = keypair
}
elemsContainer.Lock()
elemsContainer.elems = elemsContainer.elems[:i]
if elemsContainerOOO != nil {
peer.StagePackets(elemsContainerOOO) // XXX: Out of order, but we can't front-load go chans
}
if len(elemsContainer.elems) == 0 {
peer.device.PutOutboundElementsContainer(elemsContainer)
case elem := <-peer.queue.staged:
elem.peer = peer
elem.nonce = atomic.AddUint64(&keypair.sendNonce, 1) - 1
if elem.nonce >= RejectAfterMessages {
atomic.StoreUint64(&keypair.sendNonce, RejectAfterMessages)
peer.StagePacket(elem) // XXX: Out of order, but we can't front-load go chans
goto top
}
elem.keypair = keypair
elem.Lock()
// add to parallel and sequential queue
if peer.isRunning.Load() {
peer.queue.outbound.c <- elemsContainer
peer.device.queue.encryption.c <- elemsContainer
if peer.isRunning.Get() {
peer.queue.outbound.c <- elem
peer.device.queue.encryption.c <- elem
} else {
for _, elem := range elemsContainer.elems {
peer.device.PutMessageBuffer(elem.buffer)
peer.device.PutOutboundElement(elem)
}
peer.device.PutOutboundElementsContainer(elemsContainer)
}
if elemsContainerOOO != nil {
goto top
peer.device.PutMessageBuffer(elem.buffer)
peer.device.PutOutboundElement(elem)
}
default:
return
@ -471,12 +336,9 @@ top:
func (peer *Peer) FlushStagedPackets() {
for {
select {
case elemsContainer := <-peer.queue.staged:
for _, elem := range elemsContainer.elems {
peer.device.PutMessageBuffer(elem.buffer)
peer.device.PutOutboundElement(elem)
}
peer.device.PutOutboundElementsContainer(elemsContainer)
case elem := <-peer.queue.staged:
peer.device.PutMessageBuffer(elem.buffer)
peer.device.PutOutboundElement(elem)
default:
return
}
@ -510,38 +372,41 @@ func (device *Device) RoutineEncryption(id int) {
defer device.log.Verbosef("Routine: encryption worker %d - stopped", id)
device.log.Verbosef("Routine: encryption worker %d - started", id)
for elemsContainer := range device.queue.encryption.c {
for _, elem := range elemsContainer.elems {
// populate header fields
header := elem.buffer[:MessageTransportHeaderSize]
for elem := range device.queue.encryption.c {
// populate header fields
header := elem.buffer[:MessageTransportHeaderSize]
fieldType := header[0:4]
fieldReceiver := header[4:8]
fieldNonce := header[8:16]
fieldType := header[0:4]
fieldReceiver := header[4:8]
fieldNonce := header[8:16]
binary.LittleEndian.PutUint32(fieldType, MessageTransportType)
binary.LittleEndian.PutUint32(fieldReceiver, elem.keypair.remoteIndex)
binary.LittleEndian.PutUint64(fieldNonce, elem.nonce)
binary.LittleEndian.PutUint32(fieldType, MessageTransportType)
binary.LittleEndian.PutUint32(fieldReceiver, elem.keypair.remoteIndex)
binary.LittleEndian.PutUint64(fieldNonce, elem.nonce)
// pad content to multiple of 16
paddingSize := calculatePaddingSize(len(elem.packet), int(device.tun.mtu.Load()))
elem.packet = append(elem.packet, paddingZeros[:paddingSize]...)
// pad content to multiple of 16
paddingSize := calculatePaddingSize(len(elem.packet), int(atomic.LoadInt32(&device.tun.mtu)))
elem.packet = append(elem.packet, paddingZeros[:paddingSize]...)
// encrypt content and release to consumer
// encrypt content and release to consumer
binary.LittleEndian.PutUint64(nonce[4:], elem.nonce)
elem.packet = elem.keypair.send.Seal(
header,
nonce[:],
elem.packet,
nil,
)
}
elemsContainer.Unlock()
binary.LittleEndian.PutUint64(nonce[4:], elem.nonce)
elem.packet = elem.keypair.send.Seal(
header,
nonce[:],
elem.packet,
nil,
)
elem.Unlock()
}
}
func (peer *Peer) RoutineSequentialSender(maxBatchSize int) {
/* Sequentially reads packets from queue and sends to endpoint
*
* Obs. Single instance per peer.
* The routine terminates then the outbound queue is closed.
*/
func (peer *Peer) RoutineSequentialSender() {
device := peer.device
defer func() {
defer device.log.Verbosef("%v - Routine: sequential sender - stopped", peer)
@ -549,57 +414,36 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) {
}()
device.log.Verbosef("%v - Routine: sequential sender - started", peer)
bufs := make([][]byte, 0, maxBatchSize)
for elemsContainer := range peer.queue.outbound.c {
bufs = bufs[:0]
if elemsContainer == nil {
for elem := range peer.queue.outbound.c {
if elem == nil {
return
}
if !peer.isRunning.Load() {
elem.Lock()
if !peer.isRunning.Get() {
// peer has been stopped; return re-usable elems to the shared pool.
// This is an optimization only. It is possible for the peer to be stopped
// immediately after this check, in which case, elem will get processed.
// The timers and SendBuffers code are resilient to a few stragglers.
// The timers and SendBuffer code are resilient to a few stragglers.
// TODO: rework peer shutdown order to ensure
// that we never accidentally keep timers alive longer than necessary.
elemsContainer.Lock()
for _, elem := range elemsContainer.elems {
device.PutMessageBuffer(elem.buffer)
device.PutOutboundElement(elem)
}
device.PutMessageBuffer(elem.buffer)
device.PutOutboundElement(elem)
continue
}
dataSent := false
elemsContainer.Lock()
for _, elem := range elemsContainer.elems {
if len(elem.packet) != MessageKeepaliveSize {
dataSent = true
}
bufs = append(bufs, elem.packet)
}
peer.timersAnyAuthenticatedPacketTraversal()
peer.timersAnyAuthenticatedPacketSent()
err := peer.SendBuffers(bufs)
if dataSent {
// send message and return buffer to pool
err := peer.SendBuffer(elem.packet)
if len(elem.packet) != MessageKeepaliveSize {
peer.timersDataSent()
}
for _, elem := range elemsContainer.elems {
device.PutMessageBuffer(elem.buffer)
device.PutOutboundElement(elem)
}
device.PutOutboundElementsContainer(elemsContainer)
device.PutMessageBuffer(elem.buffer)
device.PutOutboundElement(elem)
if err != nil {
var errGSO conn.ErrUDPGSODisabled
if errors.As(err, &errGSO) {
device.log.Verbosef(err.Error())
err = errGSO.RetryErr
}
}
if err != nil {
device.log.Errorf("%v - Failed to send data packets: %v", peer, err)
device.log.Errorf("%v - Failed to send data packet: %v", peer, err)
continue
}

View file

@ -3,8 +3,8 @@
package device
import (
"github.com/amnezia-vpn/amneziawg-go/conn"
"github.com/amnezia-vpn/amneziawg-go/rwcancel"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/rwcancel"
)
func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) {

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*
* This implements userspace semantics of "sticky sockets", modeled after
* WireGuard's kernelspace implementation. This is more or less a straight port
@ -20,15 +20,12 @@ import (
"golang.org/x/sys/unix"
"github.com/amnezia-vpn/amneziawg-go/conn"
"github.com/amnezia-vpn/amneziawg-go/rwcancel"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/rwcancel"
)
func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) {
if !conn.StdNetSupportsStickySockets {
return nil, nil
}
if _, ok := bind.(*conn.StdNetBind); !ok {
if _, ok := bind.(*conn.LinuxSocketBind); !ok {
return nil, nil
}
@ -110,17 +107,17 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl
if !ok {
break
}
pePtr.peer.endpoint.Lock()
if &pePtr.peer.endpoint.val != pePtr.endpoint {
pePtr.peer.endpoint.Unlock()
pePtr.peer.Lock()
if &pePtr.peer.endpoint != pePtr.endpoint {
pePtr.peer.Unlock()
break
}
if uint32(pePtr.peer.endpoint.val.(*conn.StdNetEndpoint).SrcIfidx()) == ifidx {
pePtr.peer.endpoint.Unlock()
if uint32(pePtr.peer.endpoint.(*conn.LinuxSocketEndpoint).Src4().Ifindex) == ifidx {
pePtr.peer.Unlock()
break
}
pePtr.peer.endpoint.clearSrcOnTx = true
pePtr.peer.endpoint.Unlock()
pePtr.peer.endpoint.(*conn.LinuxSocketEndpoint).ClearSrc()
pePtr.peer.Unlock()
}
attr = attr[attrhdr.Len:]
}
@ -134,18 +131,18 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl
device.peers.RLock()
i := uint32(1)
for _, peer := range device.peers.keyMap {
peer.endpoint.Lock()
if peer.endpoint.val == nil {
peer.endpoint.Unlock()
peer.RLock()
if peer.endpoint == nil {
peer.RUnlock()
continue
}
nativeEP, _ := peer.endpoint.val.(*conn.StdNetEndpoint)
nativeEP, _ := peer.endpoint.(*conn.LinuxSocketEndpoint)
if nativeEP == nil {
peer.endpoint.Unlock()
peer.RUnlock()
continue
}
if nativeEP.DstIP().Is6() || nativeEP.SrcIfidx() == 0 {
peer.endpoint.Unlock()
if nativeEP.IsV6() || nativeEP.Src4().Ifindex == 0 {
peer.RUnlock()
break
}
nlmsg := struct {
@ -172,12 +169,12 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl
Len: 8,
Type: unix.RTA_DST,
},
nativeEP.DstIP().As4(),
nativeEP.Dst4().Addr,
unix.RtAttr{
Len: 8,
Type: unix.RTA_SRC,
},
nativeEP.SrcIP().As4(),
nativeEP.Src4().Src,
unix.RtAttr{
Len: 8,
Type: unix.RTA_MARK,
@ -188,10 +185,10 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl
reqPeerLock.Lock()
reqPeer[i] = peerEndpointPtr{
peer: peer,
endpoint: &peer.endpoint.val,
endpoint: &peer.endpoint,
}
reqPeerLock.Unlock()
peer.endpoint.Unlock()
peer.RUnlock()
i++
_, err := netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:])
if err != nil {
@ -207,7 +204,7 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl
}
func createNetlinkRouteSocket() (int, error) {
sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW|unix.SOCK_CLOEXEC, unix.NETLINK_ROUTE)
sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, unix.NETLINK_ROUTE)
if err != nil {
return -1, err
}

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*
* This is based heavily on timers.c from the kernel implementation.
*/
@ -9,6 +9,7 @@ package device
import (
"sync"
"sync/atomic"
"time"
_ "unsafe"
)
@ -73,11 +74,11 @@ func (timer *Timer) IsPending() bool {
}
func (peer *Peer) timersActive() bool {
return peer.isRunning.Load() && peer.device != nil && peer.device.isUp()
return peer.isRunning.Get() && peer.device != nil && peer.device.isUp()
}
func expiredRetransmitHandshake(peer *Peer) {
if peer.timers.handshakeAttempts.Load() > MaxTimerHandshakes {
if atomic.LoadUint32(&peer.timers.handshakeAttempts) > MaxTimerHandshakes {
peer.device.log.Verbosef("%s - Handshake did not complete after %d attempts, giving up", peer, MaxTimerHandshakes+2)
if peer.timersActive() {
@ -96,11 +97,15 @@ func expiredRetransmitHandshake(peer *Peer) {
peer.timers.zeroKeyMaterial.Mod(RejectAfterTime * 3)
}
} else {
peer.timers.handshakeAttempts.Add(1)
peer.device.log.Verbosef("%s - Handshake did not complete after %d seconds, retrying (try %d)", peer, int(RekeyTimeout.Seconds()), peer.timers.handshakeAttempts.Load()+1)
atomic.AddUint32(&peer.timers.handshakeAttempts, 1)
peer.device.log.Verbosef("%s - Handshake did not complete after %d seconds, retrying (try %d)", peer, int(RekeyTimeout.Seconds()), atomic.LoadUint32(&peer.timers.handshakeAttempts)+1)
/* We clear the endpoint address src address, in case this is the cause of trouble. */
peer.markEndpointSrcForClearing()
peer.Lock()
if peer.endpoint != nil {
peer.endpoint.ClearSrc()
}
peer.Unlock()
peer.SendHandshakeInitiation(true)
}
@ -108,8 +113,8 @@ func expiredRetransmitHandshake(peer *Peer) {
func expiredSendKeepalive(peer *Peer) {
peer.SendKeepalive()
if peer.timers.needAnotherKeepalive.Load() {
peer.timers.needAnotherKeepalive.Store(false)
if peer.timers.needAnotherKeepalive.Get() {
peer.timers.needAnotherKeepalive.Set(false)
if peer.timersActive() {
peer.timers.sendKeepalive.Mod(KeepaliveTimeout)
}
@ -119,7 +124,11 @@ func expiredSendKeepalive(peer *Peer) {
func expiredNewHandshake(peer *Peer) {
peer.device.log.Verbosef("%s - Retrying handshake because we stopped hearing back after %d seconds", peer, int((KeepaliveTimeout + RekeyTimeout).Seconds()))
/* We clear the endpoint address src address, in case this is the cause of trouble. */
peer.markEndpointSrcForClearing()
peer.Lock()
if peer.endpoint != nil {
peer.endpoint.ClearSrc()
}
peer.Unlock()
peer.SendHandshakeInitiation(false)
}
@ -129,7 +138,7 @@ func expiredZeroKeyMaterial(peer *Peer) {
}
func expiredPersistentKeepalive(peer *Peer) {
if peer.persistentKeepaliveInterval.Load() > 0 {
if atomic.LoadUint32(&peer.persistentKeepaliveInterval) > 0 {
peer.SendKeepalive()
}
}
@ -147,7 +156,7 @@ func (peer *Peer) timersDataReceived() {
if !peer.timers.sendKeepalive.IsPending() {
peer.timers.sendKeepalive.Mod(KeepaliveTimeout)
} else {
peer.timers.needAnotherKeepalive.Store(true)
peer.timers.needAnotherKeepalive.Set(true)
}
}
}
@ -178,9 +187,9 @@ func (peer *Peer) timersHandshakeComplete() {
if peer.timersActive() {
peer.timers.retransmitHandshake.Del()
}
peer.timers.handshakeAttempts.Store(0)
peer.timers.sentLastMinuteHandshake.Store(false)
peer.lastHandshakeNano.Store(time.Now().UnixNano())
atomic.StoreUint32(&peer.timers.handshakeAttempts, 0)
peer.timers.sentLastMinuteHandshake.Set(false)
atomic.StoreInt64(&peer.stats.lastHandshakeNano, time.Now().UnixNano())
}
/* Should be called after an ephemeral key is created, which is before sending a handshake response or after receiving a handshake response. */
@ -192,7 +201,7 @@ func (peer *Peer) timersSessionDerived() {
/* Should be called before a packet with authentication -- keepalive, data, or handshake -- is sent, or after one is received. */
func (peer *Peer) timersAnyAuthenticatedPacketTraversal() {
keepalive := peer.persistentKeepaliveInterval.Load()
keepalive := atomic.LoadUint32(&peer.persistentKeepaliveInterval)
if keepalive > 0 && peer.timersActive() {
peer.timers.persistentKeepalive.Mod(time.Duration(keepalive) * time.Second)
}
@ -207,9 +216,9 @@ func (peer *Peer) timersInit() {
}
func (peer *Peer) timersStart() {
peer.timers.handshakeAttempts.Store(0)
peer.timers.sentLastMinuteHandshake.Store(false)
peer.timers.needAnotherKeepalive.Store(false)
atomic.StoreUint32(&peer.timers.handshakeAttempts, 0)
peer.timers.sentLastMinuteHandshake.Set(false)
peer.timers.needAnotherKeepalive.Set(false)
}
func (peer *Peer) timersStop() {

View file

@ -1,14 +1,15 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"fmt"
"sync/atomic"
"github.com/amnezia-vpn/amneziawg-go/tun"
"golang.zx2c4.com/wireguard/tun"
)
const DefaultMTU = 1420
@ -32,7 +33,7 @@ func (device *Device) RoutineTUNEventReader() {
tooLarge = fmt.Sprintf(" (too large, capped at %v)", MaxContentSize)
mtu = MaxContentSize
}
old := device.tun.mtu.Swap(int32(mtu))
old := atomic.SwapInt32(&device.tun.mtu, int32(mtu))
if int(old) != mtu {
device.log.Verbosef("MTU updated: %v%s", mtu, tooLarge)
}

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package device
@ -16,9 +16,10 @@ import (
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/amnezia-vpn/amneziawg-go/ipc"
"golang.zx2c4.com/wireguard/ipc"
)
type IPCError struct {
@ -97,63 +98,35 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
sendf("fwmark=%d", device.net.fwmark)
}
if device.isAdvancedSecurityOn() {
if device.aSecCfg.junkPacketCount != 0 {
sendf("jc=%d", device.aSecCfg.junkPacketCount)
}
if device.aSecCfg.junkPacketMinSize != 0 {
sendf("jmin=%d", device.aSecCfg.junkPacketMinSize)
}
if device.aSecCfg.junkPacketMaxSize != 0 {
sendf("jmax=%d", device.aSecCfg.junkPacketMaxSize)
}
if device.aSecCfg.initPacketJunkSize != 0 {
sendf("s1=%d", device.aSecCfg.initPacketJunkSize)
}
if device.aSecCfg.responsePacketJunkSize != 0 {
sendf("s2=%d", device.aSecCfg.responsePacketJunkSize)
}
if device.aSecCfg.initPacketMagicHeader != 0 {
sendf("h1=%d", device.aSecCfg.initPacketMagicHeader)
}
if device.aSecCfg.responsePacketMagicHeader != 0 {
sendf("h2=%d", device.aSecCfg.responsePacketMagicHeader)
}
if device.aSecCfg.underloadPacketMagicHeader != 0 {
sendf("h3=%d", device.aSecCfg.underloadPacketMagicHeader)
}
if device.aSecCfg.transportPacketMagicHeader != 0 {
sendf("h4=%d", device.aSecCfg.transportPacketMagicHeader)
}
}
for _, peer := range device.peers.keyMap {
// Serialize peer state.
peer.handshake.mutex.RLock()
keyf("public_key", (*[32]byte)(&peer.handshake.remoteStatic))
keyf("preshared_key", (*[32]byte)(&peer.handshake.presharedKey))
peer.handshake.mutex.RUnlock()
sendf("protocol_version=1")
peer.endpoint.Lock()
if peer.endpoint.val != nil {
sendf("endpoint=%s", peer.endpoint.val.DstToString())
}
peer.endpoint.Unlock()
// Do the work in an anonymous function so that we can use defer.
func() {
peer.RLock()
defer peer.RUnlock()
nano := peer.lastHandshakeNano.Load()
secs := nano / time.Second.Nanoseconds()
nano %= time.Second.Nanoseconds()
keyf("public_key", (*[32]byte)(&peer.handshake.remoteStatic))
keyf("preshared_key", (*[32]byte)(&peer.handshake.presharedKey))
sendf("protocol_version=1")
if peer.endpoint != nil {
sendf("endpoint=%s", peer.endpoint.DstToString())
}
sendf("last_handshake_time_sec=%d", secs)
sendf("last_handshake_time_nsec=%d", nano)
sendf("tx_bytes=%d", peer.txBytes.Load())
sendf("rx_bytes=%d", peer.rxBytes.Load())
sendf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval.Load())
nano := atomic.LoadInt64(&peer.stats.lastHandshakeNano)
secs := nano / time.Second.Nanoseconds()
nano %= time.Second.Nanoseconds()
device.allowedips.EntriesForPeer(peer, func(prefix netip.Prefix) bool {
sendf("allowed_ip=%s", prefix.String())
return true
})
sendf("last_handshake_time_sec=%d", secs)
sendf("last_handshake_time_nsec=%d", nano)
sendf("tx_bytes=%d", atomic.LoadUint64(&peer.stats.txBytes))
sendf("rx_bytes=%d", atomic.LoadUint64(&peer.stats.rxBytes))
sendf("persistent_keepalive_interval=%d", atomic.LoadUint32(&peer.persistentKeepaliveInterval))
device.allowedips.EntriesForPeer(peer, func(prefix netip.Prefix) bool {
sendf("allowed_ip=%s", prefix.String())
return true
})
}()
}
}()
@ -180,26 +153,17 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
peer := new(ipcSetPeer)
deviceConfig := true
tempASecCfg := aSecCfgType{}
scanner := bufio.NewScanner(r)
for scanner.Scan() {
line := scanner.Text()
if line == "" {
// Blank line means terminate operation.
err := device.handlePostConfig(&tempASecCfg)
if err != nil {
return err
}
peer.handlePostConfig()
return nil
}
key, value, ok := strings.Cut(line, "=")
if !ok {
return ipcErrorf(
ipc.IpcErrorProtocol,
"failed to parse line %q",
line,
)
return ipcErrorf(ipc.IpcErrorProtocol, "failed to parse line %q", line)
}
if key == "public_key" {
@ -217,7 +181,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
var err error
if deviceConfig {
err = device.handleDeviceLine(key, value, &tempASecCfg)
err = device.handleDeviceLine(key, value)
} else {
err = device.handlePeerLine(peer, key, value)
}
@ -225,10 +189,6 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
return err
}
}
err = device.handlePostConfig(&tempASecCfg)
if err != nil {
return err
}
peer.handlePostConfig()
if err := scanner.Err(); err != nil {
@ -237,7 +197,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
return nil
}
func (device *Device) handleDeviceLine(key, value string, tempASecCfg *aSecCfgType) error {
func (device *Device) handleDeviceLine(key, value string) error {
switch key {
case "private_key":
var sk NoisePrivateKey
@ -283,83 +243,6 @@ func (device *Device) handleDeviceLine(key, value string, tempASecCfg *aSecCfgTy
device.log.Verbosef("UAPI: Removing all peers")
device.RemoveAllPeers()
case "jc":
junkPacketCount, err := strconv.Atoi(value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse junk_packet_count %w", err)
}
device.log.Verbosef("UAPI: Updating junk_packet_count")
tempASecCfg.junkPacketCount = junkPacketCount
tempASecCfg.isSet = true
case "jmin":
junkPacketMinSize, err := strconv.Atoi(value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse junk_packet_min_size %w", err)
}
device.log.Verbosef("UAPI: Updating junk_packet_min_size")
tempASecCfg.junkPacketMinSize = junkPacketMinSize
tempASecCfg.isSet = true
case "jmax":
junkPacketMaxSize, err := strconv.Atoi(value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse junk_packet_max_size %w", err)
}
device.log.Verbosef("UAPI: Updating junk_packet_max_size")
tempASecCfg.junkPacketMaxSize = junkPacketMaxSize
tempASecCfg.isSet = true
case "s1":
initPacketJunkSize, err := strconv.Atoi(value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse init_packet_junk_size %w", err)
}
device.log.Verbosef("UAPI: Updating init_packet_junk_size")
tempASecCfg.initPacketJunkSize = initPacketJunkSize
tempASecCfg.isSet = true
case "s2":
responsePacketJunkSize, err := strconv.Atoi(value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse response_packet_junk_size %w", err)
}
device.log.Verbosef("UAPI: Updating response_packet_junk_size")
tempASecCfg.responsePacketJunkSize = responsePacketJunkSize
tempASecCfg.isSet = true
case "h1":
initPacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse init_packet_magic_header %w", err)
}
tempASecCfg.initPacketMagicHeader = uint32(initPacketMagicHeader)
tempASecCfg.isSet = true
case "h2":
responsePacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse response_packet_magic_header %w", err)
}
tempASecCfg.responsePacketMagicHeader = uint32(responsePacketMagicHeader)
tempASecCfg.isSet = true
case "h3":
underloadPacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse underload_packet_magic_header %w", err)
}
tempASecCfg.underloadPacketMagicHeader = uint32(underloadPacketMagicHeader)
tempASecCfg.isSet = true
case "h4":
transportPacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse transport_packet_magic_header %w", err)
}
tempASecCfg.transportPacketMagicHeader = uint32(transportPacketMagicHeader)
tempASecCfg.isSet = true
default:
return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI device key: %v", key)
}
@ -380,7 +263,7 @@ func (peer *ipcSetPeer) handlePostConfig() {
return
}
if peer.created {
peer.endpoint.disableRoaming = peer.device.net.brokenRoaming && peer.endpoint.val != nil
peer.disableRoaming = peer.device.net.brokenRoaming && peer.endpoint != nil
}
if peer.device.isUp() {
peer.Start()
@ -391,10 +274,7 @@ func (peer *ipcSetPeer) handlePostConfig() {
}
}
func (device *Device) handlePublicKeyLine(
peer *ipcSetPeer,
value string,
) error {
func (device *Device) handlePublicKeyLine(peer *ipcSetPeer, value string) error {
// Load/create the peer we are configuring.
var publicKey NoisePublicKey
err := publicKey.FromHex(value)
@ -424,10 +304,7 @@ func (device *Device) handlePublicKeyLine(
return nil
}
func (device *Device) handlePeerLine(
peer *ipcSetPeer,
key, value string,
) error {
func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error {
switch key {
case "update_only":
// allow disabling of creation
@ -469,9 +346,9 @@ func (device *Device) handlePeerLine(
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set endpoint %v: %w", value, err)
}
peer.endpoint.Lock()
defer peer.endpoint.Unlock()
peer.endpoint.val = endpoint
peer.Lock()
defer peer.Unlock()
peer.endpoint = endpoint
case "persistent_keepalive_interval":
device.log.Verbosef("%v - UAPI: Updating persistent keepalive interval", peer.Peer)
@ -481,7 +358,7 @@ func (device *Device) handlePeerLine(
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set persistent keepalive interval: %w", err)
}
old := peer.persistentKeepaliveInterval.Swap(uint32(secs))
old := atomic.SwapUint32(&peer.persistentKeepaliveInterval, uint32(secs))
// Send immediate keepalive if we're turning it on and before it wasn't on.
peer.pkaOn = old == 0 && secs != 0

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2021 WireGuard LLC. All Rights Reserved.
*/
package main

19
go.mod
View file

@ -1,17 +1,10 @@
module github.com/amnezia-vpn/amneziawg-go
module golang.zx2c4.com/wireguard
go 1.24
go 1.18
require (
github.com/tevino/abool/v2 v2.1.0
golang.org/x/crypto v0.36.0
golang.org/x/net v0.37.0
golang.org/x/sys v0.31.0
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
gvisor.dev/gvisor v0.0.0-20250130013005-04f9204697c6
)
require (
github.com/google/btree v1.1.3 // indirect
golang.org/x/time v0.9.0 // indirect
golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd
golang.org/x/net v0.0.0-20220225172249-27dd8689420f
golang.org/x/sys v0.0.0-20220315194320-039c03cc5b86
golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224
)

28
go.sum
View file

@ -1,20 +1,8 @@
github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg=
github.com/google/btree v1.1.3/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/tevino/abool/v2 v2.1.0 h1:7w+Vf9f/5gmKT4m4qkayb33/92M+Um45F2BkHOR+L/c=
github.com/tevino/abool/v2 v2.1.0/go.mod h1:+Lmlqk6bHDWHqN1cbxqhwEAwMPXgc8I1SDEamtseuXY=
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
golang.org/x/mod v0.21.0 h1:vvrHzRwRfVKSiLrG+d4FMl/Qi4ukBCE6kZlTUkDYRT0=
golang.org/x/mod v0.21.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY=
golang.org/x/net v0.37.0 h1:1zLorHbz+LYj7MQlSf1+2tPIIgibq2eL5xkrGk6f+2c=
golang.org/x/net v0.37.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/time v0.9.0 h1:EsRrnYcQiGH+5FfbgvV4AP7qEZstoyrHB0DzarOQ4ZY=
golang.org/x/time v0.9.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
gvisor.dev/gvisor v0.0.0-20250130013005-04f9204697c6 h1:6B7MdW3OEbJqOMr7cEYU9bkzvCjUBX/JlXk12xcANuQ=
gvisor.dev/gvisor v0.0.0-20250130013005-04f9204697c6/go.mod h1:5DMfjtclAbTIjbXqO1qCe2K5GKKxWz2JHvCChuTcJEM=
golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd h1:XcWmESyNjXJMLahc3mqVQJcgSTDxFxhETVlfk9uGc38=
golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/net v0.0.0-20220225172249-27dd8689420f h1:oA4XRj0qtSt8Yo1Zms0CUlsT3KG69V2UGQWPBxujDmc=
golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
golang.org/x/sys v0.0.0-20220315194320-039c03cc5b86 h1:A9i04dxx7Cribqbs8jf3FQLogkL/CV2YN7hj9KWJCkc=
golang.org/x/sys v0.0.0-20220315194320-039c03cc5b86/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224 h1:Ug9qvr1myri/zFN6xL17LSCBGFDnphBBhzmILHsM5TY=
golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=

View file

@ -4,6 +4,7 @@
// license that can be found in the LICENSE file.
//go:build windows
// +build windows
package namedpipe
@ -53,7 +54,7 @@ type file struct {
handle windows.Handle
wg sync.WaitGroup
wgLock sync.RWMutex
closing atomic.Bool
closing uint32 // used as atomic boolean
socket bool
readDeadline deadlineHandler
writeDeadline deadlineHandler
@ -64,7 +65,7 @@ type deadlineHandler struct {
channel timeoutChan
channelLock sync.RWMutex
timer *time.Timer
timedout atomic.Bool
timedout uint32 // used as atomic boolean
}
// makeFile makes a new file from an existing file handle
@ -88,7 +89,7 @@ func makeFile(h windows.Handle) (*file, error) {
func (f *file) closeHandle() {
f.wgLock.Lock()
// Atomically set that we are closing, releasing the resources only once.
if f.closing.Swap(true) == false {
if atomic.SwapUint32(&f.closing, 1) == 0 {
f.wgLock.Unlock()
// cancel all IO and wait for it to complete
windows.CancelIoEx(f.handle, nil)
@ -111,7 +112,7 @@ func (f *file) Close() error {
// The caller must call f.wg.Done() when the IO is finished, prior to Close() returning.
func (f *file) prepareIo() (*ioOperation, error) {
f.wgLock.RLock()
if f.closing.Load() {
if atomic.LoadUint32(&f.closing) == 1 {
f.wgLock.RUnlock()
return nil, os.ErrClosed
}
@ -143,7 +144,7 @@ func (f *file) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, err err
return int(bytes), err
}
if f.closing.Load() {
if atomic.LoadUint32(&f.closing) == 1 {
windows.CancelIoEx(f.handle, &c.o)
}
@ -159,7 +160,7 @@ func (f *file) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, err err
case r = <-c.ch:
err = r.err
if err == windows.ERROR_OPERATION_ABORTED {
if f.closing.Load() {
if atomic.LoadUint32(&f.closing) == 1 {
err = os.ErrClosed
}
} else if err != nil && f.socket {
@ -191,7 +192,7 @@ func (f *file) Read(b []byte) (int, error) {
}
defer f.wg.Done()
if f.readDeadline.timedout.Load() {
if atomic.LoadUint32(&f.readDeadline.timedout) == 1 {
return 0, os.ErrDeadlineExceeded
}
@ -218,7 +219,7 @@ func (f *file) Write(b []byte) (int, error) {
}
defer f.wg.Done()
if f.writeDeadline.timedout.Load() {
if atomic.LoadUint32(&f.writeDeadline.timedout) == 1 {
return 0, os.ErrDeadlineExceeded
}
@ -255,7 +256,7 @@ func (d *deadlineHandler) set(deadline time.Time) error {
}
d.timer = nil
}
d.timedout.Store(false)
atomic.StoreUint32(&d.timedout, 0)
select {
case <-d.channel:
@ -270,7 +271,7 @@ func (d *deadlineHandler) set(deadline time.Time) error {
}
timeoutIO := func() {
d.timedout.Store(true)
atomic.StoreUint32(&d.timedout, 1)
close(d.channel)
}

View file

@ -4,6 +4,7 @@
// license that can be found in the LICENSE file.
//go:build windows
// +build windows
// Package namedpipe implements a net.Conn and net.Listener around Windows named pipes.
package namedpipe
@ -28,7 +29,7 @@ type pipe struct {
type messageBytePipe struct {
pipe
writeClosed atomic.Bool
writeClosed int32
readEOF bool
}
@ -50,17 +51,17 @@ func (f *pipe) SetDeadline(t time.Time) error {
// CloseWrite closes the write side of a message pipe in byte mode.
func (f *messageBytePipe) CloseWrite() error {
if !f.writeClosed.CompareAndSwap(false, true) {
if !atomic.CompareAndSwapInt32(&f.writeClosed, 0, 1) {
return io.ErrClosedPipe
}
err := f.file.Flush()
if err != nil {
f.writeClosed.Store(false)
atomic.StoreInt32(&f.writeClosed, 0)
return err
}
_, err = f.file.Write(nil)
if err != nil {
f.writeClosed.Store(false)
atomic.StoreInt32(&f.writeClosed, 0)
return err
}
return nil
@ -69,7 +70,7 @@ func (f *messageBytePipe) CloseWrite() error {
// Write writes bytes to a message pipe in byte mode. Zero-byte writes are ignored, since
// they are used to implement CloseWrite.
func (f *messageBytePipe) Write(b []byte) (int, error) {
if f.writeClosed.Load() {
if atomic.LoadInt32(&f.writeClosed) != 0 {
return 0, io.ErrClosedPipe
}
if len(b) == 0 {

View file

@ -4,6 +4,7 @@
// license that can be found in the LICENSE file.
//go:build windows
// +build windows
package namedpipe_test
@ -20,8 +21,8 @@ import (
"testing"
"time"
"github.com/amnezia-vpn/amneziawg-go/ipc/namedpipe"
"golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/ipc/namedpipe"
)
func randomPipePath() string {

View file

@ -2,7 +2,7 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package ipc

View file

@ -1,11 +1,11 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2021 WireGuard LLC. All Rights Reserved.
*/
package ipc
// Made up sentinel error codes for {js,wasip1}/wasm.
// Made up sentinel error codes for the js/wasm platform.
const (
IpcErrorIO = 1
IpcErrorInvalid = 2

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package ipc
@ -9,8 +9,8 @@ import (
"net"
"os"
"github.com/amnezia-vpn/amneziawg-go/rwcancel"
"golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/rwcancel"
)
type UAPIListener struct {
@ -96,7 +96,7 @@ func UAPIListen(name string, file *os.File) (net.Listener, error) {
}
go func(l *UAPIListener) {
var buf [0]byte
var buff [0]byte
for {
defer uapi.inotifyRWCancel.Close()
// start with lstat to avoid race condition
@ -104,7 +104,7 @@ func UAPIListen(name string, file *os.File) (net.Listener, error) {
l.connErr <- err
return
}
_, err := uapi.inotifyRWCancel.Read(buf[:])
_, err := uapi.inotifyRWCancel.Read(buff[:])
if err != nil {
l.connErr <- err
return

View file

@ -2,7 +2,7 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package ipc
@ -26,7 +26,7 @@ const (
// socketDirectory is variable because it is modified by a linker
// flag in wireguard-android.
var socketDirectory = "/var/run/amneziawg"
var socketDirectory = "/var/run/wireguard"
func sockPath(iface string) string {
return fmt.Sprintf("%s/%s.sock", socketDirectory, iface)

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package ipc
@ -8,8 +8,8 @@ package ipc
import (
"net"
"github.com/amnezia-vpn/amneziawg-go/ipc/namedpipe"
"golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/ipc/namedpipe"
)
// TODO: replace these with actual standard windows error numbers from the win package
@ -62,7 +62,7 @@ func init() {
func UAPIListen(name string) (net.Listener, error) {
listener, err := (&namedpipe.ListenConfig{
SecurityDescriptor: UAPISecurityDescriptor,
}).Listen(`\\.\pipe\ProtectedPrefix\Administrators\AmneziaWG\` + name)
}).Listen(`\\.\pipe\ProtectedPrefix\Administrators\WireGuard\` + name)
if err != nil {
return nil, err
}

46
main.go
View file

@ -2,7 +2,7 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package main
@ -13,12 +13,12 @@ import (
"os/signal"
"runtime"
"strconv"
"syscall"
"github.com/amnezia-vpn/amneziawg-go/conn"
"github.com/amnezia-vpn/amneziawg-go/device"
"github.com/amnezia-vpn/amneziawg-go/ipc"
"github.com/amnezia-vpn/amneziawg-go/tun"
"golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/ipc"
"golang.zx2c4.com/wireguard/tun"
)
const (
@ -46,20 +46,20 @@ func warning() {
return
}
fmt.Fprintln(os.Stderr, "┌──────────────────────────────────────────────────────────────┐")
fmt.Fprintln(os.Stderr, "│ │")
fmt.Fprintln(os.Stderr, "│ Running amneziawg-go is not required because this │")
fmt.Fprintln(os.Stderr, "│ kernel has first class support for AmneziaWG. For │")
fmt.Fprintln(os.Stderr, "│ information on installing the kernel module, │")
fmt.Fprintln(os.Stderr, "│ please visit: │")
fmt.Fprintln(os.Stderr, "| https://github.com/amnezia-vpn/amneziawg-linux-kernel-module │")
fmt.Fprintln(os.Stderr, "│ │")
fmt.Fprintln(os.Stderr, "└──────────────────────────────────────────────────────────────┘")
fmt.Fprintln(os.Stderr, "┌──────────────────────────────────────────────────────┐")
fmt.Fprintln(os.Stderr, "│ │")
fmt.Fprintln(os.Stderr, "│ Running wireguard-go is not required because this │")
fmt.Fprintln(os.Stderr, "│ kernel has first class support for WireGuard. For │")
fmt.Fprintln(os.Stderr, "│ information on installing the kernel module, │")
fmt.Fprintln(os.Stderr, "│ please visit: │")
fmt.Fprintln(os.Stderr, "│ https://www.wireguard.com/install/ │")
fmt.Fprintln(os.Stderr, "│ │")
fmt.Fprintln(os.Stderr, "└──────────────────────────────────────────────────────┘")
}
func main() {
if len(os.Args) == 2 && os.Args[1] == "--version" {
fmt.Printf("amneziawg-go %s\n\nUserspace AmneziaWG daemon for %s-%s.\nInformation available at https://amnezia.org\n", Version, runtime.GOOS, runtime.GOARCH)
fmt.Printf("wireguard-go v%s\n\nUserspace WireGuard daemon for %s-%s.\nInformation available at https://www.wireguard.com.\nCopyright (C) Jason A. Donenfeld <Jason@zx2c4.com>.\n", Version, runtime.GOOS, runtime.GOARCH)
return
}
@ -111,7 +111,7 @@ func main() {
// open TUN device (or use supplied fd)
tdev, err := func() (tun.Device, error) {
tun, err := func() (tun.Device, error) {
tunFdStr := os.Getenv(ENV_WG_TUN_FD)
if tunFdStr == "" {
return tun.CreateTUN(interfaceName, device.DefaultMTU)
@ -124,7 +124,7 @@ func main() {
return nil, err
}
err = unix.SetNonblock(int(fd), true)
err = syscall.SetNonblock(int(fd), true)
if err != nil {
return nil, err
}
@ -134,7 +134,7 @@ func main() {
}()
if err == nil {
realInterfaceName, err2 := tdev.Name()
realInterfaceName, err2 := tun.Name()
if err2 == nil {
interfaceName = realInterfaceName
}
@ -145,7 +145,7 @@ func main() {
fmt.Sprintf("(%s) ", interfaceName),
)
logger.Verbosef("Starting amneziawg-go version %s", Version)
logger.Verbosef("Starting wireguard-go version %s", Version)
if err != nil {
logger.Errorf("Failed to create TUN device: %v", err)
@ -196,7 +196,7 @@ func main() {
files[0], // stdin
files[1], // stdout
files[2], // stderr
tdev.File(),
tun.File(),
fileUAPI,
},
Dir: ".",
@ -222,7 +222,7 @@ func main() {
return
}
device := device.NewDevice(tdev, conn.NewDefaultBind(), logger)
device := device.NewDevice(tun, conn.NewDefaultBind(), logger)
logger.Verbosef("Device started")
@ -250,7 +250,7 @@ func main() {
// wait for program to terminate
signal.Notify(term, unix.SIGTERM)
signal.Notify(term, syscall.SIGTERM)
signal.Notify(term, os.Interrupt)
select {

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package main
@ -9,14 +9,13 @@ import (
"fmt"
"os"
"os/signal"
"syscall"
"golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/ipc"
"github.com/amnezia-vpn/amneziawg-go/conn"
"github.com/amnezia-vpn/amneziawg-go/device"
"github.com/amnezia-vpn/amneziawg-go/ipc"
"github.com/amnezia-vpn/amneziawg-go/tun"
"golang.zx2c4.com/wireguard/tun"
)
const (
@ -30,13 +29,13 @@ func main() {
}
interfaceName := os.Args[1]
fmt.Fprintln(os.Stderr, "Warning: this is a test program for Windows, mainly used for debugging this Go package. For a real AmneziaWG for Windows client, please visit: https://amnezia.org")
fmt.Fprintln(os.Stderr, "Warning: this is a test program for Windows, mainly used for debugging this Go package. For a real WireGuard for Windows client, the repo you want is <https://git.zx2c4.com/wireguard-windows/>, which includes this code as a module.")
logger := device.NewLogger(
device.LogLevelVerbose,
fmt.Sprintf("(%s) ", interfaceName),
)
logger.Verbosef("Starting amneziawg-go version %s", Version)
logger.Verbosef("Starting wireguard-go version %s", Version)
tun, err := tun.CreateTUN(interfaceName, 0)
if err == nil {
@ -82,7 +81,7 @@ func main() {
signal.Notify(term, os.Interrupt)
signal.Notify(term, os.Kill)
signal.Notify(term, windows.SIGTERM)
signal.Notify(term, syscall.SIGTERM)
select {
case <-term:

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package ratelimiter

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package ratelimiter

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
// Package replay implements an efficient anti-replay algorithm as specified in RFC 6479.

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package replay

View file

@ -1,8 +1,8 @@
//go:build !windows && !wasm
//go:build !windows && !js
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
// Package rwcancel implements cancelable read/write operations on

View file

@ -1,4 +1,4 @@
//go:build windows || wasm
//go:build windows || js
// SPDX-License-Identifier: MIT

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package tai64n

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package tai64n

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package tun

View file

@ -1,118 +0,0 @@
package tun
import "encoding/binary"
// TODO: Explore SIMD and/or other assembly optimizations.
// TODO: Test native endian loads. See RFC 1071 section 2 part B.
func checksumNoFold(b []byte, initial uint64) uint64 {
ac := initial
for len(b) >= 128 {
ac += uint64(binary.BigEndian.Uint32(b[:4]))
ac += uint64(binary.BigEndian.Uint32(b[4:8]))
ac += uint64(binary.BigEndian.Uint32(b[8:12]))
ac += uint64(binary.BigEndian.Uint32(b[12:16]))
ac += uint64(binary.BigEndian.Uint32(b[16:20]))
ac += uint64(binary.BigEndian.Uint32(b[20:24]))
ac += uint64(binary.BigEndian.Uint32(b[24:28]))
ac += uint64(binary.BigEndian.Uint32(b[28:32]))
ac += uint64(binary.BigEndian.Uint32(b[32:36]))
ac += uint64(binary.BigEndian.Uint32(b[36:40]))
ac += uint64(binary.BigEndian.Uint32(b[40:44]))
ac += uint64(binary.BigEndian.Uint32(b[44:48]))
ac += uint64(binary.BigEndian.Uint32(b[48:52]))
ac += uint64(binary.BigEndian.Uint32(b[52:56]))
ac += uint64(binary.BigEndian.Uint32(b[56:60]))
ac += uint64(binary.BigEndian.Uint32(b[60:64]))
ac += uint64(binary.BigEndian.Uint32(b[64:68]))
ac += uint64(binary.BigEndian.Uint32(b[68:72]))
ac += uint64(binary.BigEndian.Uint32(b[72:76]))
ac += uint64(binary.BigEndian.Uint32(b[76:80]))
ac += uint64(binary.BigEndian.Uint32(b[80:84]))
ac += uint64(binary.BigEndian.Uint32(b[84:88]))
ac += uint64(binary.BigEndian.Uint32(b[88:92]))
ac += uint64(binary.BigEndian.Uint32(b[92:96]))
ac += uint64(binary.BigEndian.Uint32(b[96:100]))
ac += uint64(binary.BigEndian.Uint32(b[100:104]))
ac += uint64(binary.BigEndian.Uint32(b[104:108]))
ac += uint64(binary.BigEndian.Uint32(b[108:112]))
ac += uint64(binary.BigEndian.Uint32(b[112:116]))
ac += uint64(binary.BigEndian.Uint32(b[116:120]))
ac += uint64(binary.BigEndian.Uint32(b[120:124]))
ac += uint64(binary.BigEndian.Uint32(b[124:128]))
b = b[128:]
}
if len(b) >= 64 {
ac += uint64(binary.BigEndian.Uint32(b[:4]))
ac += uint64(binary.BigEndian.Uint32(b[4:8]))
ac += uint64(binary.BigEndian.Uint32(b[8:12]))
ac += uint64(binary.BigEndian.Uint32(b[12:16]))
ac += uint64(binary.BigEndian.Uint32(b[16:20]))
ac += uint64(binary.BigEndian.Uint32(b[20:24]))
ac += uint64(binary.BigEndian.Uint32(b[24:28]))
ac += uint64(binary.BigEndian.Uint32(b[28:32]))
ac += uint64(binary.BigEndian.Uint32(b[32:36]))
ac += uint64(binary.BigEndian.Uint32(b[36:40]))
ac += uint64(binary.BigEndian.Uint32(b[40:44]))
ac += uint64(binary.BigEndian.Uint32(b[44:48]))
ac += uint64(binary.BigEndian.Uint32(b[48:52]))
ac += uint64(binary.BigEndian.Uint32(b[52:56]))
ac += uint64(binary.BigEndian.Uint32(b[56:60]))
ac += uint64(binary.BigEndian.Uint32(b[60:64]))
b = b[64:]
}
if len(b) >= 32 {
ac += uint64(binary.BigEndian.Uint32(b[:4]))
ac += uint64(binary.BigEndian.Uint32(b[4:8]))
ac += uint64(binary.BigEndian.Uint32(b[8:12]))
ac += uint64(binary.BigEndian.Uint32(b[12:16]))
ac += uint64(binary.BigEndian.Uint32(b[16:20]))
ac += uint64(binary.BigEndian.Uint32(b[20:24]))
ac += uint64(binary.BigEndian.Uint32(b[24:28]))
ac += uint64(binary.BigEndian.Uint32(b[28:32]))
b = b[32:]
}
if len(b) >= 16 {
ac += uint64(binary.BigEndian.Uint32(b[:4]))
ac += uint64(binary.BigEndian.Uint32(b[4:8]))
ac += uint64(binary.BigEndian.Uint32(b[8:12]))
ac += uint64(binary.BigEndian.Uint32(b[12:16]))
b = b[16:]
}
if len(b) >= 8 {
ac += uint64(binary.BigEndian.Uint32(b[:4]))
ac += uint64(binary.BigEndian.Uint32(b[4:8]))
b = b[8:]
}
if len(b) >= 4 {
ac += uint64(binary.BigEndian.Uint32(b))
b = b[4:]
}
if len(b) >= 2 {
ac += uint64(binary.BigEndian.Uint16(b))
b = b[2:]
}
if len(b) == 1 {
ac += uint64(b[0]) << 8
}
return ac
}
func checksum(b []byte, initial uint64) uint16 {
ac := checksumNoFold(b, initial)
ac = (ac >> 16) + (ac & 0xffff)
ac = (ac >> 16) + (ac & 0xffff)
ac = (ac >> 16) + (ac & 0xffff)
ac = (ac >> 16) + (ac & 0xffff)
return uint16(ac)
}
func pseudoHeaderChecksumNoFold(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint64 {
sum := checksumNoFold(srcAddr, 0)
sum = checksumNoFold(dstAddr, sum)
sum = checksumNoFold([]byte{0, protocol}, sum)
tmp := make([]byte, 2)
binary.BigEndian.PutUint16(tmp, totalLen)
return checksumNoFold(tmp, sum)
}

View file

@ -1,35 +0,0 @@
package tun
import (
"fmt"
"math/rand"
"testing"
)
func BenchmarkChecksum(b *testing.B) {
lengths := []int{
64,
128,
256,
512,
1024,
1500,
2048,
4096,
8192,
9000,
9001,
}
for _, length := range lengths {
b.Run(fmt.Sprintf("%d", length), func(b *testing.B) {
buf := make([]byte, length)
rng := rand.New(rand.NewSource(1))
rng.Read(buf)
b.ResetTimer()
for i := 0; i < b.N; i++ {
checksum(buf, 0)
}
})
}
}

View file

@ -1,12 +0,0 @@
package tun
import (
"errors"
)
var (
// ErrTooManySegments is returned by Device.Read() when segmentation
// overflows the length of supplied buffers. This error should not cause
// reads to cease.
ErrTooManySegments = errors.New("too many segments")
)

View file

@ -1,8 +1,9 @@
//go:build ignore
// +build ignore
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
*/
package main
@ -13,24 +14,24 @@ import (
"net/http"
"net/netip"
"github.com/amnezia-vpn/amneziawg-go/conn"
"github.com/amnezia-vpn/amneziawg-go/device"
"github.com/amnezia-vpn/amneziawg-go/tun/netstack"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun/netstack"
)
func main() {
tun, tnet, err := netstack.CreateNetTUN(
[]netip.Addr{netip.MustParseAddr("192.168.4.28")},
[]netip.Addr{netip.MustParseAddr("192.168.4.29")},
[]netip.Addr{netip.MustParseAddr("8.8.8.8")},
1420)
if err != nil {
log.Panic(err)
}
dev := device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(device.LogLevelVerbose, ""))
err = dev.IpcSet(`private_key=087ec6e14bbed210e7215cdc73468dfa23f080a1bfb8665b2fd809bd99d28379
public_key=c4c8e984c5322c8184c72265b92b250fdb63688705f504ba003c88f03393cf28
dev.IpcSet(`private_key=a8dac1d8a70a751f0f699fb14ba1cff7b79cf4fbd8f09f44c6e6a90d0369604f
public_key=25123c5dcd3328ff645e4f2a3fce0d754400d3887a0cb7c56f0267e20fbf3c5b
endpoint=163.172.161.0:12912
allowed_ip=0.0.0.0/0
endpoint=127.0.0.1:58120
`)
err = dev.Up()
if err != nil {
@ -42,7 +43,7 @@ endpoint=127.0.0.1:58120
DialContext: tnet.DialContext,
},
}
resp, err := client.Get("http://192.168.4.29/")
resp, err := client.Get("https://www.zx2c4.com/ip")
if err != nil {
log.Panic(err)
}

View file

@ -1,8 +1,9 @@
//go:build ignore
// +build ignore
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
*/
package main
@ -14,9 +15,9 @@ import (
"net/http"
"net/netip"
"github.com/amnezia-vpn/amneziawg-go/conn"
"github.com/amnezia-vpn/amneziawg-go/device"
"github.com/amnezia-vpn/amneziawg-go/tun/netstack"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun/netstack"
)
func main() {
@ -29,10 +30,10 @@ func main() {
log.Panic(err)
}
dev := device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(device.LogLevelVerbose, ""))
dev.IpcSet(`private_key=003ed5d73b55806c30de3f8a7bdab38af13539220533055e635690b8b87ad641
listen_port=58120
public_key=f928d4f6c1b86c12f2562c10b07c555c5c57fd00f59e90c8d8d88767271cbf7c
allowed_ip=192.168.4.28/32
dev.IpcSet(`private_key=a8dac1d8a70a751f0f699fb14ba1cff7b79cf4fbd8f09f44c6e6a90d0369604f
public_key=25123c5dcd3328ff645e4f2a3fce0d754400d3887a0cb7c56f0267e20fbf3c5b
endpoint=163.172.161.0:12912
allowed_ip=0.0.0.0/0
persistent_keepalive_interval=25
`)
dev.Up()

View file

@ -1,8 +1,9 @@
//go:build ignore
// +build ignore
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
*/
package main
@ -17,9 +18,9 @@ import (
"golang.org/x/net/icmp"
"golang.org/x/net/ipv4"
"github.com/amnezia-vpn/amneziawg-go/conn"
"github.com/amnezia-vpn/amneziawg-go/device"
"github.com/amnezia-vpn/amneziawg-go/tun/netstack"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun/netstack"
)
func main() {

17
tun/netstack/go.mod Normal file
View file

@ -0,0 +1,17 @@
module golang.zx2c4.com/wireguard/tun/netstack
go 1.18
require (
golang.org/x/net v0.0.0-20220225172249-27dd8689420f
golang.zx2c4.com/wireguard v0.0.0-20220316235147-5aff28b14c24
gvisor.dev/gvisor v0.0.0-20211020211948-f76a604701b6
)
require (
github.com/google/btree v1.0.1 // indirect
golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd // indirect
golang.org/x/sys v0.0.0-20220315194320-039c03cc5b86 // indirect
golang.org/x/time v0.0.0-20191024005414-555d28b269f0 // indirect
golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224 // indirect
)

Some files were not shown because too many files have changed in this diff Show more