Compare commits


No commits in common. "master" and "0.0.20190908" have entirely different histories.

152 changed files with 9841 additions and 12962 deletions

View file

@@ -1,41 +0,0 @@
name: build-if-tag
on:
push:
tags:
- 'v[0-9]+.[0-9]+.[0-9]+'
env:
APP: amneziawg-go
jobs:
build:
runs-on: ubuntu-latest
name: build
steps:
- name: Checkout
uses: actions/checkout@v4
with:
ref: ${{ github.ref_name }}
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Setup metadata
uses: docker/metadata-action@v5
id: metadata
with:
images: amneziavpn/${{ env.APP }}
tags: type=semver,pattern={{version}}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Build
uses: docker/build-push-action@v5
with:
push: true
tags: ${{ steps.metadata.outputs.tags }}

.gitignore (vendored)
View file

@@ -1 +1,4 @@
-amneziawg-go
+wireguard-go
+vendor
+.gopath
+ireallywantobuildon_linux.go

View file

@@ -1,17 +0,0 @@
FROM golang:1.24 as awg
COPY . /awg
WORKDIR /awg
RUN go mod download && \
go mod verify && \
go build -ldflags '-linkmode external -extldflags "-fno-PIC -static"' -v -o /usr/bin
FROM alpine:3.19
ARG AWGTOOLS_RELEASE="1.0.20241018"
RUN apk --no-cache add iproute2 iptables bash && \
cd /usr/bin/ && \
wget https://github.com/amnezia-vpn/amneziawg-tools/releases/download/v${AWGTOOLS_RELEASE}/alpine-3.19-amneziawg-tools.zip && \
unzip -j alpine-3.19-amneziawg-tools.zip && \
chmod +x /usr/bin/awg /usr/bin/awg-quick && \
ln -s /usr/bin/awg /usr/bin/wg && \
ln -s /usr/bin/awg-quick /usr/bin/wg-quick
COPY --from=awg /usr/bin/amneziawg-go /usr/bin/amneziawg-go

View file

@@ -9,23 +9,20 @@ MAKEFLAGS += --no-print-directory
 generate-version-and-build:
 	@export GIT_CEILING_DIRECTORIES="$(realpath $(CURDIR)/..)" && \
-	tag="$$(git describe --tags --dirty 2>/dev/null)" && \
-	ver="$$(printf 'package main\n\nconst Version = "%s"\n' "$$tag")" && \
-	[ "$$(cat version.go 2>/dev/null)" != "$$ver" ] && \
-	echo "$$ver" > version.go && \
-	git update-index --assume-unchanged version.go || true
-	@$(MAKE) amneziawg-go
+	tag="$$(git describe --dirty 2>/dev/null)" && \
+	ver="$$(printf 'package device\nconst WireGuardGoVersion = "%s"\n' "$$tag")" && \
+	[ "$$(cat device/version.go 2>/dev/null)" != "$$ver" ] && \
+	echo "$$ver" > device/version.go && \
+	git update-index --assume-unchanged device/version.go || true
+	@$(MAKE) wireguard-go
-amneziawg-go: $(wildcard *.go) $(wildcard */*.go)
+wireguard-go: $(wildcard *.go) $(wildcard */*.go)
 	go build -v -o "$@"
-install: amneziawg-go
-	@install -v -d "$(DESTDIR)$(BINDIR)" && install -v -m 0755 "$<" "$(DESTDIR)$(BINDIR)/amneziawg-go"
-test:
-	go test ./...
+install: wireguard-go
+	@install -v -d "$(DESTDIR)$(BINDIR)" && install -v -m 0755 "$<" "$(DESTDIR)$(BINDIR)/wireguard-go"
 clean:
-	rm -f amneziawg-go
+	rm -f wireguard-go
-.PHONY: all clean test install generate-version-and-build
+.PHONY: all clean install generate-version-and-build

View file

@@ -1,27 +1,24 @@
-# Go Implementation of AmneziaWG
-AmneziaWG is a contemporary version of the WireGuard protocol. It's a fork of WireGuard-Go and offers protection against detection by Deep Packet Inspection (DPI) systems. At the same time, it retains the simplified architecture and high performance of the original.
-The precursor, WireGuard, is known for its efficiency but had issues with detection due to its distinctive packet signatures.
-AmneziaWG addresses this problem by employing advanced obfuscation methods, allowing its traffic to blend seamlessly with regular internet traffic.
-As a result, AmneziaWG maintains high performance while adding an extra layer of stealth, making it a superb choice for those seeking a fast and discreet VPN connection.
+# Go Implementation of [WireGuard](https://www.wireguard.com/)
+This is an implementation of WireGuard in Go.
 ## Usage
-Simply run:
+Most Linux kernel WireGuard users are used to adding an interface with `ip link add wg0 type wireguard`. With wireguard-go, instead simply run:
 ```
-$ amneziawg-go wg0
+$ wireguard-go wg0
 ```
-This will create an interface and fork into the background. To remove the interface, use the usual `ip link del wg0`, or if your system does not support removing interfaces directly, you may instead remove the control socket via `rm -f /var/run/amneziawg/wg0.sock`, which will result in amneziawg-go shutting down.
+This will create an interface and fork into the background. To remove the interface, use the usual `ip link del wg0`, or if your system does not support removing interfaces directly, you may instead remove the control socket via `rm -f /var/run/wireguard/wg0.sock`, which will result in wireguard-go shutting down.
-To run amneziawg-go without forking to the background, pass `-f` or `--foreground`:
+To run wireguard-go without forking to the background, pass `-f` or `--foreground`:
 ```
-$ amneziawg-go -f wg0
+$ wireguard-go -f wg0
 ```
-When an interface is running, you may use [`amneziawg-tools`](https://github.com/amnezia-vpn/amneziawg-tools) to configure it, as well as the usual `ip(8)` and `ifconfig(8)` commands.
+When an interface is running, you may use [`wg(8)`](https://git.zx2c4.com/WireGuard/about/src/tools/man/wg.8) to configure it, as well as the usual `ip(8)` and `ifconfig(8)` commands.
 To run with more logging you may set the environment variable `LOG_LEVEL=debug`.
@@ -29,24 +26,52 @@ To run with more logging you may set the environment variable `LOG_LEVEL=debug`.
 ### Linux
-This will run on Linux; you should run amnezia-wg instead of using default linux kernel module.
+This will run on Linux; however **YOU SHOULD NOT RUN THIS ON LINUX**. Instead use the kernel module; see the [installation page](https://www.wireguard.com/install/) for instructions.
 ### macOS
 This runs on macOS using the utun driver. It does not yet support sticky sockets, and won't support fwmarks because of Darwin limitations. Since the utun driver cannot have arbitrary interface names, you must either use `utun[0-9]+` for an explicit interface name or `utun` to have the kernel select one for you. If you choose `utun` as the interface name, and the environment variable `WG_TUN_NAME_FILE` is defined, then the actual name of the interface chosen by the kernel is written to the file specified by that variable.
-This runs on MacOS, you should use it from [amneziawg-apple](https://github.com/amnezia-vpn/amneziawg-apple)
 ### Windows
-This runs on Windows, you should use it from [amneziawg-windows](https://github.com/amnezia-vpn/amneziawg-windows), which uses this as a module.
+This runs on Windows, but you should instead use it from the more [fully featured Windows app](https://git.zx2c4.com/wireguard-windows/about/), which uses this as a module.
+### FreeBSD
+This will run on FreeBSD. It does not yet support sticky sockets. Fwmark is mapped to `SO_USER_COOKIE`.
+### OpenBSD
+This will run on OpenBSD. It does not yet support sticky sockets. Fwmark is mapped to `SO_RTABLE`. Since the tun driver cannot have arbitrary interface names, you must either use `tun[0-9]+` for an explicit interface name or `tun` to have the program select one for you. If you choose `tun` as the interface name, and the environment variable `WG_TUN_NAME_FILE` is defined, then the actual name of the interface chosen by the kernel is written to the file specified by that variable.
 ## Building
-This requires an installation of the latest version of [Go](https://go.dev/).
+This requires an installation of [go](https://golang.org) ≥ 1.12.
 ```
-$ git clone https://github.com/amnezia-vpn/amneziawg-go
-$ cd amneziawg-go
+$ git clone https://git.zx2c4.com/wireguard-go
+$ cd wireguard-go
 $ make
 ```
+## License
+Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+Permission is hereby granted, free of charge, to any person obtaining a copy of
+this software and associated documentation files (the "Software"), to deal in
+the Software without restriction, including without limitation the rights to
+use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
+of the Software, and to permit persons to whom the Software is furnished to do
+so, subject to the following conditions:
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.

View file

@@ -1,544 +0,0 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"runtime"
"strconv"
"sync"
"syscall"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
)
var (
_ Bind = (*StdNetBind)(nil)
)
// StdNetBind implements Bind for all platforms. While Windows has its own Bind
// (see bind_windows.go), it may fall back to StdNetBind.
// TODO: Remove usage of ipv{4,6}.PacketConn when net.UDPConn has comparable
// methods for sending and receiving multiple datagrams per-syscall. See the
// proposal in https://github.com/golang/go/issues/45886#issuecomment-1218301564.
type StdNetBind struct {
mu sync.Mutex // protects all fields except as specified
ipv4 *net.UDPConn
ipv6 *net.UDPConn
ipv4PC *ipv4.PacketConn // will be nil on non-Linux
ipv6PC *ipv6.PacketConn // will be nil on non-Linux
ipv4TxOffload bool
ipv4RxOffload bool
ipv6TxOffload bool
ipv6RxOffload bool
// these two fields are not guarded by mu
udpAddrPool sync.Pool
msgsPool sync.Pool
blackhole4 bool
blackhole6 bool
}
func NewStdNetBind() Bind {
return &StdNetBind{
udpAddrPool: sync.Pool{
New: func() any {
return &net.UDPAddr{
IP: make([]byte, 16),
}
},
},
msgsPool: sync.Pool{
New: func() any {
// ipv6.Message and ipv4.Message are interchangeable as they are
// both aliases for x/net/internal/socket.Message.
msgs := make([]ipv6.Message, IdealBatchSize)
for i := range msgs {
msgs[i].Buffers = make(net.Buffers, 1)
msgs[i].OOB = make([]byte, 0, stickyControlSize+gsoControlSize)
}
return &msgs
},
},
}
}
type StdNetEndpoint struct {
// AddrPort is the endpoint destination.
netip.AddrPort
// src is the current sticky source address and interface index, if
// supported. Typically this is a PKTINFO structure from/for control
// messages, see unix.PKTINFO for an example.
src []byte
}
var (
_ Bind = (*StdNetBind)(nil)
_ Endpoint = &StdNetEndpoint{}
)
func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) {
e, err := netip.ParseAddrPort(s)
if err != nil {
return nil, err
}
return &StdNetEndpoint{
AddrPort: e,
}, nil
}
func (e *StdNetEndpoint) ClearSrc() {
if e.src != nil {
// Truncate src, no need to reallocate.
e.src = e.src[:0]
}
}
func (e *StdNetEndpoint) DstIP() netip.Addr {
return e.AddrPort.Addr()
}
// See control_default,linux, etc for implementations of SrcIP and SrcIfidx.
func (e *StdNetEndpoint) DstToBytes() []byte {
b, _ := e.AddrPort.MarshalBinary()
return b
}
func (e *StdNetEndpoint) DstToString() string {
return e.AddrPort.String()
}
func listenNet(network string, port int) (*net.UDPConn, int, error) {
conn, err := listenConfig().ListenPacket(context.Background(), network, ":"+strconv.Itoa(port))
if err != nil {
return nil, 0, err
}
// Retrieve port.
laddr := conn.LocalAddr()
uaddr, err := net.ResolveUDPAddr(
laddr.Network(),
laddr.String(),
)
if err != nil {
return nil, 0, err
}
return conn.(*net.UDPConn), uaddr.Port, nil
}
func (s *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
s.mu.Lock()
defer s.mu.Unlock()
var err error
var tries int
if s.ipv4 != nil || s.ipv6 != nil {
return nil, 0, ErrBindAlreadyOpen
}
// Attempt to open ipv4 and ipv6 listeners on the same port.
// If uport is 0, we can retry on failure.
again:
port := int(uport)
var v4conn, v6conn *net.UDPConn
var v4pc *ipv4.PacketConn
var v6pc *ipv6.PacketConn
v4conn, port, err = listenNet("udp4", port)
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
return nil, 0, err
}
// Listen on the same port as we're using for ipv4.
v6conn, port, err = listenNet("udp6", port)
if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
v4conn.Close()
tries++
goto again
}
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
v4conn.Close()
return nil, 0, err
}
var fns []ReceiveFunc
if v4conn != nil {
s.ipv4TxOffload, s.ipv4RxOffload = supportsUDPOffload(v4conn)
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
v4pc = ipv4.NewPacketConn(v4conn)
s.ipv4PC = v4pc
}
fns = append(fns, s.makeReceiveIPv4(v4pc, v4conn, s.ipv4RxOffload))
s.ipv4 = v4conn
}
if v6conn != nil {
s.ipv6TxOffload, s.ipv6RxOffload = supportsUDPOffload(v6conn)
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
v6pc = ipv6.NewPacketConn(v6conn)
s.ipv6PC = v6pc
}
fns = append(fns, s.makeReceiveIPv6(v6pc, v6conn, s.ipv6RxOffload))
s.ipv6 = v6conn
}
if len(fns) == 0 {
return nil, 0, syscall.EAFNOSUPPORT
}
return fns, uint16(port), nil
}
func (s *StdNetBind) putMessages(msgs *[]ipv6.Message) {
for i := range *msgs {
(*msgs)[i].OOB = (*msgs)[i].OOB[:0]
(*msgs)[i] = ipv6.Message{Buffers: (*msgs)[i].Buffers, OOB: (*msgs)[i].OOB}
}
s.msgsPool.Put(msgs)
}
func (s *StdNetBind) getMessages() *[]ipv6.Message {
return s.msgsPool.Get().(*[]ipv6.Message)
}
var (
// If compilation fails here these are no longer the same underlying type.
_ ipv6.Message = ipv4.Message{}
)
type batchReader interface {
ReadBatch([]ipv6.Message, int) (int, error)
}
type batchWriter interface {
WriteBatch([]ipv6.Message, int) (int, error)
}
func (s *StdNetBind) receiveIP(
br batchReader,
conn *net.UDPConn,
rxOffload bool,
bufs [][]byte,
sizes []int,
eps []Endpoint,
) (n int, err error) {
msgs := s.getMessages()
for i := range bufs {
(*msgs)[i].Buffers[0] = bufs[i]
(*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)]
}
defer s.putMessages(msgs)
var numMsgs int
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
if rxOffload {
readAt := len(*msgs) - (IdealBatchSize / udpSegmentMaxDatagrams)
numMsgs, err = br.ReadBatch((*msgs)[readAt:], 0)
if err != nil {
return 0, err
}
numMsgs, err = splitCoalescedMessages(*msgs, readAt, getGSOSize)
if err != nil {
return 0, err
}
} else {
numMsgs, err = br.ReadBatch(*msgs, 0)
if err != nil {
return 0, err
}
}
} else {
msg := &(*msgs)[0]
msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
if err != nil {
return 0, err
}
numMsgs = 1
}
for i := 0; i < numMsgs; i++ {
msg := &(*msgs)[i]
sizes[i] = msg.N
if sizes[i] == 0 {
continue
}
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
getSrcFromControl(msg.OOB[:msg.NN], ep)
eps[i] = ep
}
return numMsgs, nil
}
func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc {
return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps)
}
}
func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc {
return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps)
}
}
// TODO: When all Binds handle IdealBatchSize, remove this dynamic function and
// rename the IdealBatchSize constant to BatchSize.
func (s *StdNetBind) BatchSize() int {
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
return IdealBatchSize
}
return 1
}
func (s *StdNetBind) Close() error {
s.mu.Lock()
defer s.mu.Unlock()
var err1, err2 error
if s.ipv4 != nil {
err1 = s.ipv4.Close()
s.ipv4 = nil
s.ipv4PC = nil
}
if s.ipv6 != nil {
err2 = s.ipv6.Close()
s.ipv6 = nil
s.ipv6PC = nil
}
s.blackhole4 = false
s.blackhole6 = false
s.ipv4TxOffload = false
s.ipv4RxOffload = false
s.ipv6TxOffload = false
s.ipv6RxOffload = false
if err1 != nil {
return err1
}
return err2
}
type ErrUDPGSODisabled struct {
onLaddr string
RetryErr error
}
func (e ErrUDPGSODisabled) Error() string {
return fmt.Sprintf("disabled UDP GSO on %s, NIC(s) may not support checksum offload or peer MTU with protocol headers is greater than path MTU", e.onLaddr)
}
func (e ErrUDPGSODisabled) Unwrap() error {
return e.RetryErr
}
func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint) error {
s.mu.Lock()
blackhole := s.blackhole4
conn := s.ipv4
offload := s.ipv4TxOffload
br := batchWriter(s.ipv4PC)
is6 := false
if endpoint.DstIP().Is6() {
blackhole = s.blackhole6
conn = s.ipv6
br = s.ipv6PC
is6 = true
offload = s.ipv6TxOffload
}
s.mu.Unlock()
if blackhole {
return nil
}
if conn == nil {
return syscall.EAFNOSUPPORT
}
msgs := s.getMessages()
defer s.putMessages(msgs)
ua := s.udpAddrPool.Get().(*net.UDPAddr)
defer s.udpAddrPool.Put(ua)
if is6 {
as16 := endpoint.DstIP().As16()
copy(ua.IP, as16[:])
ua.IP = ua.IP[:16]
} else {
as4 := endpoint.DstIP().As4()
copy(ua.IP, as4[:])
ua.IP = ua.IP[:4]
}
ua.Port = int(endpoint.(*StdNetEndpoint).Port())
var (
retried bool
err error
)
retry:
if offload {
n := coalesceMessages(ua, endpoint.(*StdNetEndpoint), bufs, *msgs, setGSOSize)
err = s.send(conn, br, (*msgs)[:n])
if err != nil && offload && errShouldDisableUDPGSO(err) {
offload = false
s.mu.Lock()
if is6 {
s.ipv6TxOffload = false
} else {
s.ipv4TxOffload = false
}
s.mu.Unlock()
retried = true
goto retry
}
} else {
for i := range bufs {
(*msgs)[i].Addr = ua
(*msgs)[i].Buffers[0] = bufs[i]
setSrcControl(&(*msgs)[i].OOB, endpoint.(*StdNetEndpoint))
}
err = s.send(conn, br, (*msgs)[:len(bufs)])
}
if retried {
return ErrUDPGSODisabled{onLaddr: conn.LocalAddr().String(), RetryErr: err}
}
return err
}
func (s *StdNetBind) send(conn *net.UDPConn, pc batchWriter, msgs []ipv6.Message) error {
var (
n int
err error
start int
)
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
for {
n, err = pc.WriteBatch(msgs[start:], 0)
if err != nil || n == len(msgs[start:]) {
break
}
start += n
}
} else {
for _, msg := range msgs {
_, _, err = conn.WriteMsgUDP(msg.Buffers[0], msg.OOB, msg.Addr.(*net.UDPAddr))
if err != nil {
break
}
}
}
return err
}
const (
// Exceeding these values results in EMSGSIZE. They account for layer3 and
// layer4 headers. IPv6 does not need to account for itself as the payload
// length field is self excluding.
maxIPv4PayloadLen = 1<<16 - 1 - 20 - 8
maxIPv6PayloadLen = 1<<16 - 1 - 8
// This is a hard limit imposed by the kernel.
udpSegmentMaxDatagrams = 64
)
type setGSOFunc func(control *[]byte, gsoSize uint16)
func coalesceMessages(addr *net.UDPAddr, ep *StdNetEndpoint, bufs [][]byte, msgs []ipv6.Message, setGSO setGSOFunc) int {
var (
base = -1 // index of msg we are currently coalescing into
gsoSize int // segmentation size of msgs[base]
dgramCnt int // number of dgrams coalesced into msgs[base]
endBatch bool // tracking flag to start a new batch on next iteration of bufs
)
maxPayloadLen := maxIPv4PayloadLen
if ep.DstIP().Is6() {
maxPayloadLen = maxIPv6PayloadLen
}
for i, buf := range bufs {
if i > 0 {
msgLen := len(buf)
baseLenBefore := len(msgs[base].Buffers[0])
freeBaseCap := cap(msgs[base].Buffers[0]) - baseLenBefore
if msgLen+baseLenBefore <= maxPayloadLen &&
msgLen <= gsoSize &&
msgLen <= freeBaseCap &&
dgramCnt < udpSegmentMaxDatagrams &&
!endBatch {
msgs[base].Buffers[0] = append(msgs[base].Buffers[0], buf...)
if i == len(bufs)-1 {
setGSO(&msgs[base].OOB, uint16(gsoSize))
}
dgramCnt++
if msgLen < gsoSize {
// A smaller than gsoSize packet on the tail is legal, but
// it must end the batch.
endBatch = true
}
continue
}
}
if dgramCnt > 1 {
setGSO(&msgs[base].OOB, uint16(gsoSize))
}
// Reset prior to incrementing base since we are preparing to start a
// new potential batch.
endBatch = false
base++
gsoSize = len(buf)
setSrcControl(&msgs[base].OOB, ep)
msgs[base].Buffers[0] = buf
msgs[base].Addr = addr
dgramCnt = 1
}
return base + 1
}
type getGSOFunc func(control []byte) (int, error)
func splitCoalescedMessages(msgs []ipv6.Message, firstMsgAt int, getGSO getGSOFunc) (n int, err error) {
for i := firstMsgAt; i < len(msgs); i++ {
msg := &msgs[i]
if msg.N == 0 {
return n, err
}
var (
gsoSize int
start int
end = msg.N
numToSplit = 1
)
gsoSize, err = getGSO(msg.OOB[:msg.NN])
if err != nil {
return n, err
}
if gsoSize > 0 {
numToSplit = (msg.N + gsoSize - 1) / gsoSize
end = gsoSize
}
for j := 0; j < numToSplit; j++ {
if n > i {
return n, errors.New("splitting coalesced packet resulted in overflow")
}
copied := copy(msgs[n].Buffers[0], msg.Buffers[0][start:end])
msgs[n].N = copied
msgs[n].Addr = msg.Addr
start = end
end += gsoSize
if end > msg.N {
end = msg.N
}
n++
}
if i != n-1 {
// It is legal for bytes to move within msg.Buffers[0] as a result
// of splitting, so we only zero the source msg len when it is not
// the destination of the last split operation above.
msg.N = 0
}
}
return n, nil
}
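Note that when UDP GSO has to be turned off mid-send, Send retries without offload and still returns an ErrUDPGSODisabled whose wrapped error is nil if that retry succeeded. A small, hypothetical caller-side sketch of how that might be unwrapped, using only the exported API shown above (the helper name and package are illustrative, not part of the repository):

```go
package gsohandling

import (
	"errors"
	"log"

	"github.com/amnezia-vpn/amneziawg-go/conn"
)

// sendLoggingGSO is a hypothetical helper: it forwards bufs to bind.Send and
// reports when the bind had to disable UDP GSO for this socket.
func sendLoggingGSO(bind conn.Bind, bufs [][]byte, ep conn.Endpoint) error {
	err := bind.Send(bufs, ep)
	var gsoErr conn.ErrUDPGSODisabled
	if errors.As(err, &gsoErr) {
		log.Printf("falling back from UDP GSO: %v", gsoErr)
		// A nil wrapped error means the retried, non-offloaded send succeeded.
		return gsoErr.Unwrap()
	}
	return err
}
```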

View file

@@ -1,250 +0,0 @@
package conn
import (
"encoding/binary"
"net"
"testing"
"golang.org/x/net/ipv6"
)
func TestStdNetBindReceiveFuncAfterClose(t *testing.T) {
bind := NewStdNetBind().(*StdNetBind)
fns, _, err := bind.Open(0)
if err != nil {
t.Fatal(err)
}
bind.Close()
bufs := make([][]byte, 1)
bufs[0] = make([]byte, 1)
sizes := make([]int, 1)
eps := make([]Endpoint, 1)
for _, fn := range fns {
// The ReceiveFuncs must not access conn-related fields on StdNetBind
// unguarded. Close() nils the conn-related fields resulting in a panic
// if they violate the mutex.
fn(bufs, sizes, eps)
}
}
func mockSetGSOSize(control *[]byte, gsoSize uint16) {
*control = (*control)[:cap(*control)]
binary.LittleEndian.PutUint16(*control, gsoSize)
}
func Test_coalesceMessages(t *testing.T) {
cases := []struct {
name string
buffs [][]byte
wantLens []int
wantGSO []int
}{
{
name: "one message no coalesce",
buffs: [][]byte{
make([]byte, 1, 1),
},
wantLens: []int{1},
wantGSO: []int{0},
},
{
name: "two messages equal len coalesce",
buffs: [][]byte{
make([]byte, 1, 2),
make([]byte, 1, 1),
},
wantLens: []int{2},
wantGSO: []int{1},
},
{
name: "two messages unequal len coalesce",
buffs: [][]byte{
make([]byte, 2, 3),
make([]byte, 1, 1),
},
wantLens: []int{3},
wantGSO: []int{2},
},
{
name: "three messages second unequal len coalesce",
buffs: [][]byte{
make([]byte, 2, 3),
make([]byte, 1, 1),
make([]byte, 2, 2),
},
wantLens: []int{3, 2},
wantGSO: []int{2, 0},
},
{
name: "three messages limited cap coalesce",
buffs: [][]byte{
make([]byte, 2, 4),
make([]byte, 2, 2),
make([]byte, 2, 2),
},
wantLens: []int{4, 2},
wantGSO: []int{2, 0},
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
addr := &net.UDPAddr{
IP: net.ParseIP("127.0.0.1").To4(),
Port: 1,
}
msgs := make([]ipv6.Message, len(tt.buffs))
for i := range msgs {
msgs[i].Buffers = make([][]byte, 1)
msgs[i].OOB = make([]byte, 0, 2)
}
got := coalesceMessages(addr, &StdNetEndpoint{AddrPort: addr.AddrPort()}, tt.buffs, msgs, mockSetGSOSize)
if got != len(tt.wantLens) {
t.Fatalf("got len %d want: %d", got, len(tt.wantLens))
}
for i := 0; i < got; i++ {
if msgs[i].Addr != addr {
t.Errorf("msgs[%d].Addr != passed addr", i)
}
gotLen := len(msgs[i].Buffers[0])
if gotLen != tt.wantLens[i] {
t.Errorf("len(msgs[%d].Buffers[0]) %d != %d", i, gotLen, tt.wantLens[i])
}
gotGSO, err := mockGetGSOSize(msgs[i].OOB)
if err != nil {
t.Fatalf("msgs[%d] getGSOSize err: %v", i, err)
}
if gotGSO != tt.wantGSO[i] {
t.Errorf("msgs[%d] gsoSize %d != %d", i, gotGSO, tt.wantGSO[i])
}
}
})
}
}
func mockGetGSOSize(control []byte) (int, error) {
if len(control) < 2 {
return 0, nil
}
return int(binary.LittleEndian.Uint16(control)), nil
}
func Test_splitCoalescedMessages(t *testing.T) {
newMsg := func(n, gso int) ipv6.Message {
msg := ipv6.Message{
Buffers: [][]byte{make([]byte, 1<<16-1)},
N: n,
OOB: make([]byte, 2),
}
binary.LittleEndian.PutUint16(msg.OOB, uint16(gso))
if gso > 0 {
msg.NN = 2
}
return msg
}
cases := []struct {
name string
msgs []ipv6.Message
firstMsgAt int
wantNumEval int
wantMsgLens []int
wantErr bool
}{
{
name: "second last split last empty",
msgs: []ipv6.Message{
newMsg(0, 0),
newMsg(0, 0),
newMsg(3, 1),
newMsg(0, 0),
},
firstMsgAt: 2,
wantNumEval: 3,
wantMsgLens: []int{1, 1, 1, 0},
wantErr: false,
},
{
name: "second last no split last empty",
msgs: []ipv6.Message{
newMsg(0, 0),
newMsg(0, 0),
newMsg(1, 0),
newMsg(0, 0),
},
firstMsgAt: 2,
wantNumEval: 1,
wantMsgLens: []int{1, 0, 0, 0},
wantErr: false,
},
{
name: "second last no split last no split",
msgs: []ipv6.Message{
newMsg(0, 0),
newMsg(0, 0),
newMsg(1, 0),
newMsg(1, 0),
},
firstMsgAt: 2,
wantNumEval: 2,
wantMsgLens: []int{1, 1, 0, 0},
wantErr: false,
},
{
name: "second last no split last split",
msgs: []ipv6.Message{
newMsg(0, 0),
newMsg(0, 0),
newMsg(1, 0),
newMsg(3, 1),
},
firstMsgAt: 2,
wantNumEval: 4,
wantMsgLens: []int{1, 1, 1, 1},
wantErr: false,
},
{
name: "second last split last split",
msgs: []ipv6.Message{
newMsg(0, 0),
newMsg(0, 0),
newMsg(2, 1),
newMsg(2, 1),
},
firstMsgAt: 2,
wantNumEval: 4,
wantMsgLens: []int{1, 1, 1, 1},
wantErr: false,
},
{
name: "second last no split last split overflow",
msgs: []ipv6.Message{
newMsg(0, 0),
newMsg(0, 0),
newMsg(1, 0),
newMsg(4, 1),
},
firstMsgAt: 2,
wantNumEval: 4,
wantMsgLens: []int{1, 1, 1, 1},
wantErr: true,
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
got, err := splitCoalescedMessages(tt.msgs, 2, mockGetGSOSize)
if err != nil && !tt.wantErr {
t.Fatalf("err: %v", err)
}
if got != tt.wantNumEval {
t.Fatalf("got to eval: %d want: %d", got, tt.wantNumEval)
}
for i, msg := range tt.msgs {
if msg.N != tt.wantMsgLens[i] {
t.Fatalf("msg[%d].N: %d want: %d", i, msg.N, tt.wantMsgLens[i])
}
}
})
}
}

View file

@@ -1,601 +0,0 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
import (
"encoding/binary"
"io"
"net"
"net/netip"
"strconv"
"sync"
"sync/atomic"
"unsafe"
"golang.org/x/sys/windows"
"github.com/amnezia-vpn/amneziawg-go/conn/winrio"
)
const (
packetsPerRing = 1024
bytesPerPacket = 2048 - 32
receiveSpins = 15
)
type ringPacket struct {
addr WinRingEndpoint
data [bytesPerPacket]byte
}
type ringBuffer struct {
packets uintptr
head, tail uint32
id winrio.BufferId
iocp windows.Handle
isFull bool
cq winrio.Cq
mu sync.Mutex
overlapped windows.Overlapped
}
func (rb *ringBuffer) Push() *ringPacket {
for rb.isFull {
panic("ring is full")
}
ret := (*ringPacket)(unsafe.Pointer(rb.packets + (uintptr(rb.tail%packetsPerRing) * unsafe.Sizeof(ringPacket{}))))
rb.tail += 1
if rb.tail%packetsPerRing == rb.head%packetsPerRing {
rb.isFull = true
}
return ret
}
func (rb *ringBuffer) Return(count uint32) {
if rb.head%packetsPerRing == rb.tail%packetsPerRing && !rb.isFull {
return
}
rb.head += count
rb.isFull = false
}
type afWinRingBind struct {
sock windows.Handle
rx, tx ringBuffer
rq winrio.Rq
mu sync.Mutex
blackhole bool
}
// WinRingBind uses Windows registered I/O for fast ring buffered networking.
type WinRingBind struct {
v4, v6 afWinRingBind
mu sync.RWMutex
isOpen atomic.Uint32 // 0, 1, or 2
}
func NewDefaultBind() Bind { return NewWinRingBind() }
func NewWinRingBind() Bind {
if !winrio.Initialize() {
return NewStdNetBind()
}
return new(WinRingBind)
}
type WinRingEndpoint struct {
family uint16
data [30]byte
}
var (
_ Bind = (*WinRingBind)(nil)
_ Endpoint = (*WinRingEndpoint)(nil)
)
func (*WinRingBind) ParseEndpoint(s string) (Endpoint, error) {
host, port, err := net.SplitHostPort(s)
if err != nil {
return nil, err
}
host16, err := windows.UTF16PtrFromString(host)
if err != nil {
return nil, err
}
port16, err := windows.UTF16PtrFromString(port)
if err != nil {
return nil, err
}
hints := windows.AddrinfoW{
Flags: windows.AI_NUMERICHOST,
Family: windows.AF_UNSPEC,
Socktype: windows.SOCK_DGRAM,
Protocol: windows.IPPROTO_UDP,
}
var addrinfo *windows.AddrinfoW
err = windows.GetAddrInfoW(host16, port16, &hints, &addrinfo)
if err != nil {
return nil, err
}
defer windows.FreeAddrInfoW(addrinfo)
if (addrinfo.Family != windows.AF_INET && addrinfo.Family != windows.AF_INET6) || addrinfo.Addrlen > unsafe.Sizeof(WinRingEndpoint{}) {
return nil, windows.ERROR_INVALID_ADDRESS
}
var dst [unsafe.Sizeof(WinRingEndpoint{})]byte
copy(dst[:], unsafe.Slice((*byte)(unsafe.Pointer(addrinfo.Addr)), addrinfo.Addrlen))
return (*WinRingEndpoint)(unsafe.Pointer(&dst[0])), nil
}
func (*WinRingEndpoint) ClearSrc() {}
func (e *WinRingEndpoint) DstIP() netip.Addr {
switch e.family {
case windows.AF_INET:
return netip.AddrFrom4(*(*[4]byte)(e.data[2:6]))
case windows.AF_INET6:
return netip.AddrFrom16(*(*[16]byte)(e.data[6:22]))
}
return netip.Addr{}
}
func (e *WinRingEndpoint) SrcIP() netip.Addr {
return netip.Addr{} // not supported
}
func (e *WinRingEndpoint) DstToBytes() []byte {
switch e.family {
case windows.AF_INET:
b := make([]byte, 0, 6)
b = append(b, e.data[2:6]...)
b = append(b, e.data[1], e.data[0])
return b
case windows.AF_INET6:
b := make([]byte, 0, 18)
b = append(b, e.data[6:22]...)
b = append(b, e.data[1], e.data[0])
return b
}
return nil
}
func (e *WinRingEndpoint) DstToString() string {
switch e.family {
case windows.AF_INET:
return netip.AddrPortFrom(netip.AddrFrom4(*(*[4]byte)(e.data[2:6])), binary.BigEndian.Uint16(e.data[0:2])).String()
case windows.AF_INET6:
var zone string
if scope := *(*uint32)(unsafe.Pointer(&e.data[22])); scope > 0 {
zone = strconv.FormatUint(uint64(scope), 10)
}
return netip.AddrPortFrom(netip.AddrFrom16(*(*[16]byte)(e.data[6:22])).WithZone(zone), binary.BigEndian.Uint16(e.data[0:2])).String()
}
return ""
}
func (e *WinRingEndpoint) SrcToString() string {
return ""
}
func (ring *ringBuffer) CloseAndZero() {
if ring.cq != 0 {
winrio.CloseCompletionQueue(ring.cq)
ring.cq = 0
}
if ring.iocp != 0 {
windows.CloseHandle(ring.iocp)
ring.iocp = 0
}
if ring.id != 0 {
winrio.DeregisterBuffer(ring.id)
ring.id = 0
}
if ring.packets != 0 {
windows.VirtualFree(ring.packets, 0, windows.MEM_RELEASE)
ring.packets = 0
}
ring.head = 0
ring.tail = 0
ring.isFull = false
}
func (bind *afWinRingBind) CloseAndZero() {
bind.rx.CloseAndZero()
bind.tx.CloseAndZero()
if bind.sock != 0 {
windows.CloseHandle(bind.sock)
bind.sock = 0
}
bind.blackhole = false
}
func (bind *WinRingBind) closeAndZero() {
bind.isOpen.Store(0)
bind.v4.CloseAndZero()
bind.v6.CloseAndZero()
}
func (ring *ringBuffer) Open() error {
var err error
packetsLen := unsafe.Sizeof(ringPacket{}) * packetsPerRing
ring.packets, err = windows.VirtualAlloc(0, packetsLen, windows.MEM_COMMIT|windows.MEM_RESERVE, windows.PAGE_READWRITE)
if err != nil {
return err
}
ring.id, err = winrio.RegisterPointer(unsafe.Pointer(ring.packets), uint32(packetsLen))
if err != nil {
return err
}
ring.iocp, err = windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0)
if err != nil {
return err
}
ring.cq, err = winrio.CreateIOCPCompletionQueue(packetsPerRing, ring.iocp, 0, &ring.overlapped)
if err != nil {
return err
}
return nil
}
func (bind *afWinRingBind) Open(family int32, sa windows.Sockaddr) (windows.Sockaddr, error) {
var err error
bind.sock, err = winrio.Socket(family, windows.SOCK_DGRAM, windows.IPPROTO_UDP)
if err != nil {
return nil, err
}
err = bind.rx.Open()
if err != nil {
return nil, err
}
err = bind.tx.Open()
if err != nil {
return nil, err
}
bind.rq, err = winrio.CreateRequestQueue(bind.sock, packetsPerRing, 1, packetsPerRing, 1, bind.rx.cq, bind.tx.cq, 0)
if err != nil {
return nil, err
}
err = windows.Bind(bind.sock, sa)
if err != nil {
return nil, err
}
sa, err = windows.Getsockname(bind.sock)
if err != nil {
return nil, err
}
return sa, nil
}
func (bind *WinRingBind) Open(port uint16) (recvFns []ReceiveFunc, selectedPort uint16, err error) {
bind.mu.Lock()
defer bind.mu.Unlock()
defer func() {
if err != nil {
bind.closeAndZero()
}
}()
if bind.isOpen.Load() != 0 {
return nil, 0, ErrBindAlreadyOpen
}
var sa windows.Sockaddr
sa, err = bind.v4.Open(windows.AF_INET, &windows.SockaddrInet4{Port: int(port)})
if err != nil {
return nil, 0, err
}
sa, err = bind.v6.Open(windows.AF_INET6, &windows.SockaddrInet6{Port: sa.(*windows.SockaddrInet4).Port})
if err != nil {
return nil, 0, err
}
selectedPort = uint16(sa.(*windows.SockaddrInet6).Port)
for i := 0; i < packetsPerRing; i++ {
err = bind.v4.InsertReceiveRequest()
if err != nil {
return nil, 0, err
}
err = bind.v6.InsertReceiveRequest()
if err != nil {
return nil, 0, err
}
}
bind.isOpen.Store(1)
return []ReceiveFunc{bind.receiveIPv4, bind.receiveIPv6}, selectedPort, err
}
func (bind *WinRingBind) Close() error {
bind.mu.RLock()
if bind.isOpen.Load() != 1 {
bind.mu.RUnlock()
return nil
}
bind.isOpen.Store(2)
windows.PostQueuedCompletionStatus(bind.v4.rx.iocp, 0, 0, nil)
windows.PostQueuedCompletionStatus(bind.v4.tx.iocp, 0, 0, nil)
windows.PostQueuedCompletionStatus(bind.v6.rx.iocp, 0, 0, nil)
windows.PostQueuedCompletionStatus(bind.v6.tx.iocp, 0, 0, nil)
bind.mu.RUnlock()
bind.mu.Lock()
defer bind.mu.Unlock()
bind.closeAndZero()
return nil
}
// TODO: When all Binds handle IdealBatchSize, remove this dynamic function and
// rename the IdealBatchSize constant to BatchSize.
func (bind *WinRingBind) BatchSize() int {
// TODO: implement batching in and out of the ring
return 1
}
func (bind *WinRingBind) SetMark(mark uint32) error {
return nil
}
func (bind *afWinRingBind) InsertReceiveRequest() error {
packet := bind.rx.Push()
dataBuffer := &winrio.Buffer{
Id: bind.rx.id,
Offset: uint32(uintptr(unsafe.Pointer(&packet.data[0])) - bind.rx.packets),
Length: uint32(len(packet.data)),
}
addressBuffer := &winrio.Buffer{
Id: bind.rx.id,
Offset: uint32(uintptr(unsafe.Pointer(&packet.addr)) - bind.rx.packets),
Length: uint32(unsafe.Sizeof(packet.addr)),
}
bind.mu.Lock()
defer bind.mu.Unlock()
return winrio.ReceiveEx(bind.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, uintptr(unsafe.Pointer(packet)))
}
//go:linkname procyield runtime.procyield
func procyield(cycles uint32)
func (bind *afWinRingBind) Receive(buf []byte, isOpen *atomic.Uint32) (int, Endpoint, error) {
if isOpen.Load() != 1 {
return 0, nil, net.ErrClosed
}
bind.rx.mu.Lock()
defer bind.rx.mu.Unlock()
var err error
var count uint32
var results [1]winrio.Result
retry:
count = 0
for tries := 0; count == 0 && tries < receiveSpins; tries++ {
if tries > 0 {
if isOpen.Load() != 1 {
return 0, nil, net.ErrClosed
}
procyield(1)
}
count = winrio.DequeueCompletion(bind.rx.cq, results[:])
}
if count == 0 {
err = winrio.Notify(bind.rx.cq)
if err != nil {
return 0, nil, err
}
var bytes uint32
var key uintptr
var overlapped *windows.Overlapped
err = windows.GetQueuedCompletionStatus(bind.rx.iocp, &bytes, &key, &overlapped, windows.INFINITE)
if err != nil {
return 0, nil, err
}
if isOpen.Load() != 1 {
return 0, nil, net.ErrClosed
}
count = winrio.DequeueCompletion(bind.rx.cq, results[:])
if count == 0 {
return 0, nil, io.ErrNoProgress
}
}
bind.rx.Return(1)
err = bind.InsertReceiveRequest()
if err != nil {
return 0, nil, err
}
// We limit the MTU well below the 65k max for practicality, but this means a remote host can still send us
// huge packets. Just try again when this happens. The infinite loop this could cause is still limited to
// attacker bandwidth, just like the rest of the receive path.
if windows.Errno(results[0].Status) == windows.WSAEMSGSIZE {
if isOpen.Load() != 1 {
return 0, nil, net.ErrClosed
}
goto retry
}
if results[0].Status != 0 {
return 0, nil, windows.Errno(results[0].Status)
}
packet := (*ringPacket)(unsafe.Pointer(uintptr(results[0].RequestContext)))
ep := packet.addr
n := copy(buf, packet.data[:results[0].BytesTransferred])
return n, &ep, nil
}
func (bind *WinRingBind) receiveIPv4(bufs [][]byte, sizes []int, eps []Endpoint) (int, error) {
bind.mu.RLock()
defer bind.mu.RUnlock()
n, ep, err := bind.v4.Receive(bufs[0], &bind.isOpen)
sizes[0] = n
eps[0] = ep
return 1, err
}
func (bind *WinRingBind) receiveIPv6(bufs [][]byte, sizes []int, eps []Endpoint) (int, error) {
bind.mu.RLock()
defer bind.mu.RUnlock()
n, ep, err := bind.v6.Receive(bufs[0], &bind.isOpen)
sizes[0] = n
eps[0] = ep
return 1, err
}
func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *atomic.Uint32) error {
if isOpen.Load() != 1 {
return net.ErrClosed
}
if len(buf) > bytesPerPacket {
return io.ErrShortBuffer
}
bind.tx.mu.Lock()
defer bind.tx.mu.Unlock()
var results [packetsPerRing]winrio.Result
count := winrio.DequeueCompletion(bind.tx.cq, results[:])
if count == 0 && bind.tx.isFull {
err := winrio.Notify(bind.tx.cq)
if err != nil {
return err
}
var bytes uint32
var key uintptr
var overlapped *windows.Overlapped
err = windows.GetQueuedCompletionStatus(bind.tx.iocp, &bytes, &key, &overlapped, windows.INFINITE)
if err != nil {
return err
}
if isOpen.Load() != 1 {
return net.ErrClosed
}
count = winrio.DequeueCompletion(bind.tx.cq, results[:])
if count == 0 {
return io.ErrNoProgress
}
}
if count > 0 {
bind.tx.Return(count)
}
packet := bind.tx.Push()
packet.addr = *nend
copy(packet.data[:], buf)
dataBuffer := &winrio.Buffer{
Id: bind.tx.id,
Offset: uint32(uintptr(unsafe.Pointer(&packet.data[0])) - bind.tx.packets),
Length: uint32(len(buf)),
}
addressBuffer := &winrio.Buffer{
Id: bind.tx.id,
Offset: uint32(uintptr(unsafe.Pointer(&packet.addr)) - bind.tx.packets),
Length: uint32(unsafe.Sizeof(packet.addr)),
}
bind.mu.Lock()
defer bind.mu.Unlock()
return winrio.SendEx(bind.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0)
}
func (bind *WinRingBind) Send(bufs [][]byte, endpoint Endpoint) error {
nend, ok := endpoint.(*WinRingEndpoint)
if !ok {
return ErrWrongEndpointType
}
bind.mu.RLock()
defer bind.mu.RUnlock()
for _, buf := range bufs {
switch nend.family {
case windows.AF_INET:
if bind.v4.blackhole {
continue
}
if err := bind.v4.Send(buf, nend, &bind.isOpen); err != nil {
return err
}
case windows.AF_INET6:
if bind.v6.blackhole {
continue
}
if err := bind.v6.Send(buf, nend, &bind.isOpen); err != nil {
return err
}
}
}
return nil
}
func (s *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
s.mu.Lock()
defer s.mu.Unlock()
sysconn, err := s.ipv4.SyscallConn()
if err != nil {
return err
}
err2 := sysconn.Control(func(fd uintptr) {
err = bindSocketToInterface4(windows.Handle(fd), interfaceIndex)
})
if err2 != nil {
return err2
}
if err != nil {
return err
}
s.blackhole4 = blackhole
return nil
}
func (s *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
s.mu.Lock()
defer s.mu.Unlock()
sysconn, err := s.ipv6.SyscallConn()
if err != nil {
return err
}
err2 := sysconn.Control(func(fd uintptr) {
err = bindSocketToInterface6(windows.Handle(fd), interfaceIndex)
})
if err2 != nil {
return err2
}
if err != nil {
return err
}
s.blackhole6 = blackhole
return nil
}
func (bind *WinRingBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
bind.mu.RLock()
defer bind.mu.RUnlock()
if bind.isOpen.Load() != 1 {
return net.ErrClosed
}
err := bindSocketToInterface4(bind.v4.sock, interfaceIndex)
if err != nil {
return err
}
bind.v4.blackhole = blackhole
return nil
}
func (bind *WinRingBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
bind.mu.RLock()
defer bind.mu.RUnlock()
if bind.isOpen.Load() != 1 {
return net.ErrClosed
}
err := bindSocketToInterface6(bind.v6.sock, interfaceIndex)
if err != nil {
return err
}
bind.v6.blackhole = blackhole
return nil
}
func bindSocketToInterface4(handle windows.Handle, interfaceIndex uint32) error {
const IP_UNICAST_IF = 31
/* MSDN says for IPv4 this needs to be in net byte order, so that it's like an IP address with leading zeros. */
var bytes [4]byte
binary.BigEndian.PutUint32(bytes[:], interfaceIndex)
interfaceIndex = *(*uint32)(unsafe.Pointer(&bytes[0]))
err := windows.SetsockoptInt(handle, windows.IPPROTO_IP, IP_UNICAST_IF, int(interfaceIndex))
if err != nil {
return err
}
return nil
}
func bindSocketToInterface6(handle windows.Handle, interfaceIndex uint32) error {
const IPV6_UNICAST_IF = 31
return windows.SetsockoptInt(handle, windows.IPPROTO_IPV6, IPV6_UNICAST_IF, int(interfaceIndex))
}

View file

@@ -1,136 +0,0 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package bindtest
import (
"fmt"
"math/rand"
"net"
"net/netip"
"os"
"github.com/amnezia-vpn/amneziawg-go/conn"
)
type ChannelBind struct {
rx4, tx4 *chan []byte
rx6, tx6 *chan []byte
closeSignal chan bool
source4, source6 ChannelEndpoint
target4, target6 ChannelEndpoint
}
type ChannelEndpoint uint16
var (
_ conn.Bind = (*ChannelBind)(nil)
_ conn.Endpoint = (*ChannelEndpoint)(nil)
)
func NewChannelBinds() [2]conn.Bind {
arx4 := make(chan []byte, 8192)
brx4 := make(chan []byte, 8192)
arx6 := make(chan []byte, 8192)
brx6 := make(chan []byte, 8192)
var binds [2]ChannelBind
binds[0].rx4 = &arx4
binds[0].tx4 = &brx4
binds[1].rx4 = &brx4
binds[1].tx4 = &arx4
binds[0].rx6 = &arx6
binds[0].tx6 = &brx6
binds[1].rx6 = &brx6
binds[1].tx6 = &arx6
binds[0].target4 = ChannelEndpoint(1)
binds[1].target4 = ChannelEndpoint(2)
binds[0].target6 = ChannelEndpoint(3)
binds[1].target6 = ChannelEndpoint(4)
binds[0].source4 = binds[1].target4
binds[0].source6 = binds[1].target6
binds[1].source4 = binds[0].target4
binds[1].source6 = binds[0].target6
return [2]conn.Bind{&binds[0], &binds[1]}
}
func (c ChannelEndpoint) ClearSrc() {}
func (c ChannelEndpoint) SrcToString() string { return "" }
func (c ChannelEndpoint) DstToString() string { return fmt.Sprintf("127.0.0.1:%d", c) }
func (c ChannelEndpoint) DstToBytes() []byte { return []byte{byte(c)} }
func (c ChannelEndpoint) DstIP() netip.Addr { return netip.AddrFrom4([4]byte{127, 0, 0, 1}) }
func (c ChannelEndpoint) SrcIP() netip.Addr { return netip.Addr{} }
func (c *ChannelBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
c.closeSignal = make(chan bool)
fns = append(fns, c.makeReceiveFunc(*c.rx4))
fns = append(fns, c.makeReceiveFunc(*c.rx6))
if rand.Uint32()&1 == 0 {
return fns, uint16(c.source4), nil
} else {
return fns, uint16(c.source6), nil
}
}
func (c *ChannelBind) Close() error {
if c.closeSignal != nil {
select {
case <-c.closeSignal:
default:
close(c.closeSignal)
}
}
return nil
}
func (c *ChannelBind) BatchSize() int { return 1 }
func (c *ChannelBind) SetMark(mark uint32) error { return nil }
func (c *ChannelBind) makeReceiveFunc(ch chan []byte) conn.ReceiveFunc {
return func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) {
select {
case <-c.closeSignal:
return 0, net.ErrClosed
case rx := <-ch:
copied := copy(bufs[0], rx)
sizes[0] = copied
eps[0] = c.target6
return 1, nil
}
}
}
func (c *ChannelBind) Send(bufs [][]byte, ep conn.Endpoint) error {
for _, b := range bufs {
select {
case <-c.closeSignal:
return net.ErrClosed
default:
bc := make([]byte, len(b))
copy(bc, b)
if ep.(ChannelEndpoint) == c.target4 {
*c.tx4 <- bc
} else if ep.(ChannelEndpoint) == c.target6 {
*c.tx6 <- bc
} else {
return os.ErrInvalid
}
}
}
return nil
}
func (c *ChannelBind) ParseEndpoint(s string) (conn.Endpoint, error) {
addr, err := netip.ParseAddrPort(s)
if err != nil {
return nil, err
}
return ChannelEndpoint(addr.Port()), nil
}
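ChannelBind cross-wires the transmit channels of one bind into the receive channels of the other, so two in-process "peers" can exchange packets without any sockets. A rough sketch of a test using the pair (names are illustrative, not an existing test in the repository):

```go
package bindtest_test

import (
	"testing"

	"github.com/amnezia-vpn/amneziawg-go/conn"
	"github.com/amnezia-vpn/amneziawg-go/conn/bindtest"
)

func TestChannelBindRoundTrip(t *testing.T) {
	binds := bindtest.NewChannelBinds()
	if _, _, err := binds[0].Open(0); err != nil {
		t.Fatal(err)
	}
	recvFns, _, err := binds[1].Open(0)
	if err != nil {
		t.Fatal(err)
	}
	// Port 1 is binds[0]'s IPv4 target, i.e. the channel read by binds[1].
	ep, err := binds[0].ParseEndpoint("127.0.0.1:1")
	if err != nil {
		t.Fatal(err)
	}
	if err := binds[0].Send([][]byte{[]byte("ping")}, ep); err != nil {
		t.Fatal(err)
	}
	bufs := [][]byte{make([]byte, 1500)}
	sizes := make([]int, 1)
	eps := make([]conn.Endpoint, 1)
	// recvFns[0] reads from binds[1]'s IPv4 channel.
	if _, err := recvFns[0](bufs, sizes, eps); err != nil {
		t.Fatal(err)
	}
	if string(bufs[0][:sizes[0]]) != "ping" {
		t.Fatalf("got %q, want %q", bufs[0][:sizes[0]], "ping")
	}
}
```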

View file

@@ -1,34 +0,0 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
func (s *StdNetBind) PeekLookAtSocketFd4() (fd int, err error) {
sysconn, err := s.ipv4.SyscallConn()
if err != nil {
return -1, err
}
err = sysconn.Control(func(f uintptr) {
fd = int(f)
})
if err != nil {
return -1, err
}
return
}
func (s *StdNetBind) PeekLookAtSocketFd6() (fd int, err error) {
sysconn, err := s.ipv6.SyscallConn()
if err != nil {
return -1, err
}
err = sysconn.Control(func(f uintptr) {
fd = int(f)
})
if err != nil {
return -1, err
}
return
}

View file

@@ -1,133 +0,0 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
// Package conn implements WireGuard's network connections.
package conn
import (
"errors"
"fmt"
"net/netip"
"reflect"
"runtime"
"strings"
)
const (
IdealBatchSize = 128 // maximum number of packets handled per read and write
)
// A ReceiveFunc receives at least one packet from the network and writes them
// into packets. On a successful read it returns the number of elements of
// sizes, packets, and endpoints that should be evaluated. Some elements of
// sizes may be zero, and callers should ignore them. Callers must pass a sizes
// and eps slice with a length greater than or equal to the length of packets.
// These lengths must not exceed the length of the associated Bind.BatchSize().
type ReceiveFunc func(packets [][]byte, sizes []int, eps []Endpoint) (n int, err error)
// A Bind listens on a port for both IPv6 and IPv4 UDP traffic.
//
// A Bind interface may also be a PeekLookAtSocketFd or BindSocketToInterface,
// depending on the platform-specific implementation.
type Bind interface {
// Open puts the Bind into a listening state on a given port and reports the actual
// port that it bound to. Passing zero results in a random selection.
// fns is the set of functions that will be called to receive packets.
Open(port uint16) (fns []ReceiveFunc, actualPort uint16, err error)
// Close closes the Bind listener.
// All fns returned by Open must return net.ErrClosed after a call to Close.
Close() error
// SetMark sets the mark for each packet sent through this Bind.
// This mark is passed to the kernel as the socket option SO_MARK.
SetMark(mark uint32) error
// Send writes one or more packets in bufs to address ep. The length of
// bufs must not exceed BatchSize().
Send(bufs [][]byte, ep Endpoint) error
// ParseEndpoint creates a new endpoint from a string.
ParseEndpoint(s string) (Endpoint, error)
// BatchSize is the number of buffers expected to be passed to
// the ReceiveFuncs, and the maximum expected to be passed to SendBatch.
BatchSize() int
}
// BindSocketToInterface is implemented by Bind objects that support being
// tied to a single network interface. Used by wireguard-windows.
type BindSocketToInterface interface {
BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error
BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error
}
// PeekLookAtSocketFd is implemented by Bind objects that support having their
// file descriptor peeked at. Used by wireguard-android.
type PeekLookAtSocketFd interface {
PeekLookAtSocketFd4() (fd int, err error)
PeekLookAtSocketFd6() (fd int, err error)
}
// An Endpoint maintains the source/destination caching for a peer.
//
// dst: the remote address of a peer ("endpoint" in uapi terminology)
// src: the local address from which datagrams originate going to the peer
type Endpoint interface {
ClearSrc() // clears the source address
SrcToString() string // returns the local source address (ip:port)
DstToString() string // returns the destination address (ip:port)
DstToBytes() []byte // used for mac2 cookie calculations
DstIP() netip.Addr
SrcIP() netip.Addr
}
var (
ErrBindAlreadyOpen = errors.New("bind is already open")
ErrWrongEndpointType = errors.New("endpoint type does not correspond with bind type")
)
func (fn ReceiveFunc) PrettyName() string {
name := runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name()
// 0. cheese/taco.beansIPv6.func12.func21218-fm
name = strings.TrimSuffix(name, "-fm")
// 1. cheese/taco.beansIPv6.func12.func21218
if idx := strings.LastIndexByte(name, '/'); idx != -1 {
name = name[idx+1:]
// 2. taco.beansIPv6.func12.func21218
}
for {
var idx int
for idx = len(name) - 1; idx >= 0; idx-- {
if name[idx] < '0' || name[idx] > '9' {
break
}
}
if idx == len(name)-1 {
break
}
const dotFunc = ".func"
if !strings.HasSuffix(name[:idx+1], dotFunc) {
break
}
name = name[:idx+1-len(dotFunc)]
// 3. taco.beansIPv6.func12
// 4. taco.beansIPv6
}
if idx := strings.LastIndexByte(name, '.'); idx != -1 {
name = name[idx+1:]
// 5. beansIPv6
}
if name == "" {
return fmt.Sprintf("%p", fn)
}
if strings.HasSuffix(name, "IPv4") {
return "v4"
}
if strings.HasSuffix(name, "IPv6") {
return "v6"
}
return name
}
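The Bind, Endpoint, and ReceiveFunc contracts above can be exercised with a tiny echo loop: open the default bind, size every slice to BatchSize(), and skip zero-sized results as the ReceiveFunc documentation requires. A minimal sketch under those assumptions (illustrative only; the device package is the real consumer of this API):

```go
package main

import (
	"log"

	"github.com/amnezia-vpn/amneziawg-go/conn"
)

func main() {
	bind := conn.NewDefaultBind()
	recvFns, port, err := bind.Open(0) // 0 lets the bind pick a port
	if err != nil {
		log.Fatal(err)
	}
	defer bind.Close()
	log.Printf("listening on UDP port %d", port)

	for _, fn := range recvFns {
		go func(fn conn.ReceiveFunc) {
			// Each ReceiveFunc gets its own slices, sized to the bind's batch size.
			batch := bind.BatchSize()
			bufs := make([][]byte, batch)
			for i := range bufs {
				bufs[i] = make([]byte, 65535) // large enough for any UDP payload
			}
			sizes := make([]int, batch)
			eps := make([]conn.Endpoint, batch)
			for {
				n, err := fn(bufs, sizes, eps)
				if err != nil {
					return // net.ErrClosed once the bind is closed
				}
				for i := 0; i < n; i++ {
					if sizes[i] == 0 {
						continue // callers must ignore zero-sized elements
					}
					// Echo the datagram back to the endpoint it came from.
					_ = bind.Send([][]byte{bufs[i][:sizes[i]]}, eps[i])
				}
			}
		}(fn)
	}
	select {} // serve until killed
}
```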

View file

@@ -1,24 +0,0 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
import (
"testing"
)
func TestPrettyName(t *testing.T) {
var (
recvFunc ReceiveFunc = func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { return }
)
const want = "TestPrettyName"
t.Run("ReceiveFunc.PrettyName", func(t *testing.T) {
if got := recvFunc.PrettyName(); got != want {
t.Errorf("PrettyName() = %v, want %v", got, want)
}
})
}

View file

@@ -1,43 +0,0 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
import (
"net"
"syscall"
)
// UDP socket read/write buffer size (7MB). The value of 7MB is chosen as it is
// the max supported by a default configuration of macOS. Some platforms will
// silently clamp the value to other maximums, such as linux clamping to
// net.core.{r,w}mem_max (see _linux.go for additional implementation that works
// around this limitation)
const socketBufferSize = 7 << 20
// controlFn is the callback function signature from net.ListenConfig.Control.
// It is used to apply platform specific configuration to the socket prior to
// bind.
type controlFn func(network, address string, c syscall.RawConn) error
// controlFns is a list of functions that are called from the listen config
// that can apply socket options.
var controlFns = []controlFn{}
// listenConfig returns a net.ListenConfig that applies the controlFns to the
// socket prior to bind. This is used to apply socket buffer sizing and packet
// information OOB configuration for sticky sockets.
func listenConfig() *net.ListenConfig {
return &net.ListenConfig{
Control: func(network, address string, c syscall.RawConn) error {
for _, fn := range controlFns {
if err := fn(network, address, c); err != nil {
return err
}
}
return nil
},
}
}

View file

@ -1,61 +0,0 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
import (
"fmt"
"runtime"
"syscall"
"golang.org/x/sys/unix"
)
func init() {
controlFns = append(controlFns,
// Attempt to set the socket buffer size beyond net.core.{r,w}mem_max by
// using SO_*BUFFORCE. This requires CAP_NET_ADMIN, and is allowed here to
// fail silently - the result of failure is lower performance on very fast
// links or high latency links.
func(network, address string, c syscall.RawConn) error {
return c.Control(func(fd uintptr) {
// Set up to *mem_max
_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUF, socketBufferSize)
_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUF, socketBufferSize)
// Set beyond *mem_max if CAP_NET_ADMIN
_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, socketBufferSize)
_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUFFORCE, socketBufferSize)
})
},
// Enable receiving of the packet information (IP_PKTINFO for IPv4,
// IPV6_PKTINFO for IPv6) that is used to implement sticky socket support.
func(network, address string, c syscall.RawConn) error {
var err error
switch network {
case "udp4":
if runtime.GOOS != "android" {
c.Control(func(fd uintptr) {
err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_PKTINFO, 1)
})
}
case "udp6":
c.Control(func(fd uintptr) {
if runtime.GOOS != "android" {
err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVPKTINFO, 1)
if err != nil {
return
}
}
err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_V6ONLY, 1)
})
default:
err = fmt.Errorf("unhandled network: %s: %w", network, unix.EINVAL)
}
return err
},
)
}

View file

@@ -1,35 +0,0 @@
//go:build !windows && !linux && !wasm
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
import (
"syscall"
"golang.org/x/sys/unix"
)
func init() {
controlFns = append(controlFns,
func(network, address string, c syscall.RawConn) error {
return c.Control(func(fd uintptr) {
_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUF, socketBufferSize)
_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUF, socketBufferSize)
})
},
func(network, address string, c syscall.RawConn) error {
var err error
if network == "udp6" {
c.Control(func(fd uintptr) {
err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_V6ONLY, 1)
})
}
return err
},
)
}

View file

@@ -1,23 +0,0 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
import (
"syscall"
"golang.org/x/sys/windows"
)
func init() {
controlFns = append(controlFns,
func(network, address string, c syscall.RawConn) error {
return c.Control(func(fd uintptr) {
_ = windows.SetsockoptInt(windows.Handle(fd), windows.SOL_SOCKET, windows.SO_RCVBUF, socketBufferSize)
_ = windows.SetsockoptInt(windows.Handle(fd), windows.SOL_SOCKET, windows.SO_SNDBUF, socketBufferSize)
})
},
)
}

View file

@@ -1,10 +0,0 @@
//go:build !windows
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
func NewDefaultBind() Bind { return NewStdNetBind() }

View file

@@ -1,12 +0,0 @@
//go:build !linux
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
func errShouldDisableUDPGSO(err error) bool {
return false
}

View file

@@ -1,28 +0,0 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
import (
"errors"
"os"
"golang.org/x/sys/unix"
)
func errShouldDisableUDPGSO(err error) bool {
var serr *os.SyscallError
if errors.As(err, &serr) {
// EIO is returned by udp_send_skb() if the device driver does not have
// tx checksumming enabled, which is a hard requirement of UDP_SEGMENT.
// See:
// https://git.kernel.org/pub/scm/docs/man-pages/man-pages.git/tree/man7/udp.7?id=806eabd74910447f21005160e90957bde4db0183#n228
// https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/tree/net/ipv4/udp.c?h=v6.2&id=c9c3395d5e3dcc6daee66c6908354d47bf98cb0c#n942
// If gso_size + udp + ip headers > fragment size EINVAL is returned.
// It occurs when the peer mtu + wg headers is greater than path mtu.
return serr.Err == unix.EIO || serr.Err == unix.EINVAL
}
return false
}

View file

@@ -1,15 +0,0 @@
//go:build !linux
// +build !linux
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
import "net"
func supportsUDPOffload(conn *net.UDPConn) (txOffload, rxOffload bool) {
return
}

View file

@@ -1,31 +0,0 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
import (
"net"
"golang.org/x/sys/unix"
)
func supportsUDPOffload(conn *net.UDPConn) (txOffload, rxOffload bool) {
rc, err := conn.SyscallConn()
if err != nil {
return
}
err = rc.Control(func(fd uintptr) {
_, errSyscall := unix.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_SEGMENT)
txOffload = errSyscall == nil
// getsockopt(IPPROTO_UDP, UDP_GRO) is not supported in android
// use setsockopt workaround
errSyscall = unix.SetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_GRO, 1)
rxOffload = errSyscall == nil
})
if err != nil {
return false, false
}
return txOffload, rxOffload
}

View file

@@ -1,21 +0,0 @@
//go:build !linux
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
// getGSOSize parses control for UDP_GRO and if found returns its GSO size data.
func getGSOSize(control []byte) (int, error) {
return 0, nil
}
// setGSOSize sets a UDP_SEGMENT in control based on gsoSize.
func setGSOSize(control *[]byte, gsoSize uint16) {
}
// gsoControlSize returns the recommended buffer size for pooling sticky and UDP
// offloading control data.
const gsoControlSize = 0

View file

@@ -1,65 +0,0 @@
//go:build linux
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
import (
"fmt"
"unsafe"
"golang.org/x/sys/unix"
)
const (
sizeOfGSOData = 2
)
// getGSOSize parses control for UDP_GRO and if found returns its GSO size data.
func getGSOSize(control []byte) (int, error) {
var (
hdr unix.Cmsghdr
data []byte
rem = control
err error
)
for len(rem) > unix.SizeofCmsghdr {
hdr, data, rem, err = unix.ParseOneSocketControlMessage(rem)
if err != nil {
return 0, fmt.Errorf("error parsing socket control message: %w", err)
}
if hdr.Level == unix.SOL_UDP && hdr.Type == unix.UDP_GRO && len(data) >= sizeOfGSOData {
var gso uint16
copy(unsafe.Slice((*byte)(unsafe.Pointer(&gso)), sizeOfGSOData), data[:sizeOfGSOData])
return int(gso), nil
}
}
return 0, nil
}
// setGSOSize sets a UDP_SEGMENT in control based on gsoSize. It leaves existing
// data in control untouched.
func setGSOSize(control *[]byte, gsoSize uint16) {
existingLen := len(*control)
avail := cap(*control) - existingLen
space := unix.CmsgSpace(sizeOfGSOData)
if avail < space {
return
}
*control = (*control)[:cap(*control)]
gsoControl := (*control)[existingLen:]
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&(gsoControl)[0]))
hdr.Level = unix.SOL_UDP
hdr.Type = unix.UDP_SEGMENT
hdr.SetLen(unix.CmsgLen(sizeOfGSOData))
copy((gsoControl)[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&gsoSize)), sizeOfGSOData))
*control = (*control)[:existingLen+space]
}
// gsoControlSize returns the recommended buffer size for pooling UDP
// offloading control data.
var gsoControlSize = unix.CmsgSpace(sizeOfGSOData)
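The intended call pattern is to size the control buffer with gsoControlSize and let setGSOSize append the UDP_SEGMENT header; a small hypothetical helper makes that concrete:

// Illustrative sketch: build a control buffer carrying the desired GSO
// segment size for a batched send.
func buildGSOControl(segmentSize uint16) []byte {
	control := make([]byte, 0, gsoControlSize)
	setGSOSize(&control, segmentSize) // appends a SOL_UDP/UDP_SEGMENT cmsg
	return control
}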

View file

@ -1,12 +0,0 @@
//go:build !linux && !openbsd && !freebsd
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
func (s *StdNetBind) SetMark(mark uint32) error {
return nil
}

View file

@ -1,42 +0,0 @@
//go:build !linux || android
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
import "net/netip"
func (e *StdNetEndpoint) SrcIP() netip.Addr {
return netip.Addr{}
}
func (e *StdNetEndpoint) SrcIfidx() int32 {
return 0
}
func (e *StdNetEndpoint) SrcToString() string {
return ""
}
// TODO: macOS, FreeBSD, and other BSDs likely do support the sticky sockets
// {get,set}srcControl feature set, but they use differently named flags and
// would need porting and testing.
// getSrcFromControl parses the control for PKTINFO and if found updates ep with
// the source information found.
func getSrcFromControl(control []byte, ep *StdNetEndpoint) {
}
// setSrcControl parses the control for PKTINFO and if found updates ep with
// the source information found.
func setSrcControl(control *[]byte, ep *StdNetEndpoint) {
}
// stickyControlSize returns the recommended buffer size for pooling sticky
// offloading control data.
const stickyControlSize = 0
const StdNetSupportsStickySockets = false

View file

@ -1,112 +0,0 @@
//go:build linux && !android
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
import (
"net/netip"
"unsafe"
"golang.org/x/sys/unix"
)
func (e *StdNetEndpoint) SrcIP() netip.Addr {
switch len(e.src) {
case unix.CmsgSpace(unix.SizeofInet4Pktinfo):
info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
return netip.AddrFrom4(info.Spec_dst)
case unix.CmsgSpace(unix.SizeofInet6Pktinfo):
info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
// TODO: set zone. In order to do so we need to check if the address is
// link local, and if it is perform a syscall to turn the ifindex into a
// zone string because netip uses string zones.
return netip.AddrFrom16(info.Addr)
}
return netip.Addr{}
}
func (e *StdNetEndpoint) SrcIfidx() int32 {
switch len(e.src) {
case unix.CmsgSpace(unix.SizeofInet4Pktinfo):
info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
return info.Ifindex
case unix.CmsgSpace(unix.SizeofInet6Pktinfo):
info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
return int32(info.Ifindex)
}
return 0
}
func (e *StdNetEndpoint) SrcToString() string {
return e.SrcIP().String()
}
// getSrcFromControl parses the control for PKTINFO and if found updates ep with
// the source information found.
func getSrcFromControl(control []byte, ep *StdNetEndpoint) {
ep.ClearSrc()
var (
hdr unix.Cmsghdr
data []byte
rem []byte = control
err error
)
for len(rem) > unix.SizeofCmsghdr {
hdr, data, rem, err = unix.ParseOneSocketControlMessage(rem)
if err != nil {
return
}
if hdr.Level == unix.IPPROTO_IP &&
hdr.Type == unix.IP_PKTINFO {
if ep.src == nil || cap(ep.src) < unix.CmsgSpace(unix.SizeofInet4Pktinfo) {
ep.src = make([]byte, 0, unix.CmsgSpace(unix.SizeofInet4Pktinfo))
}
ep.src = ep.src[:unix.CmsgSpace(unix.SizeofInet4Pktinfo)]
hdrBuf := unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), unix.SizeofCmsghdr)
copy(ep.src, hdrBuf)
copy(ep.src[unix.CmsgLen(0):], data)
return
}
if hdr.Level == unix.IPPROTO_IPV6 &&
hdr.Type == unix.IPV6_PKTINFO {
if ep.src == nil || cap(ep.src) < unix.CmsgSpace(unix.SizeofInet6Pktinfo) {
ep.src = make([]byte, 0, unix.CmsgSpace(unix.SizeofInet6Pktinfo))
}
ep.src = ep.src[:unix.CmsgSpace(unix.SizeofInet6Pktinfo)]
hdrBuf := unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), unix.SizeofCmsghdr)
copy(ep.src, hdrBuf)
copy(ep.src[unix.CmsgLen(0):], data)
return
}
}
}
// setSrcControl sets an IP{V6}_PKTINFO in control based on the source address
// and source ifindex found in ep. control's len will be set to 0 in the event
// that ep is a default value.
func setSrcControl(control *[]byte, ep *StdNetEndpoint) {
if cap(*control) < len(ep.src) {
return
}
*control = (*control)[:0]
*control = append(*control, ep.src...)
}
// stickyControlSize returns the recommended buffer size for pooling sticky
// offloading control data.
var stickyControlSize = unix.CmsgSpace(unix.SizeofInet6Pktinfo)
const StdNetSupportsStickySockets = true

View file

@ -1,266 +0,0 @@
//go:build linux && !android
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
import (
"context"
"net"
"net/netip"
"runtime"
"testing"
"unsafe"
"golang.org/x/sys/unix"
)
func setSrc(ep *StdNetEndpoint, addr netip.Addr, ifidx int32) {
var buf []byte
if addr.Is4() {
buf = make([]byte, unix.CmsgSpace(unix.SizeofInet4Pktinfo))
hdr := unix.Cmsghdr{
Level: unix.IPPROTO_IP,
Type: unix.IP_PKTINFO,
}
hdr.SetLen(unix.CmsgLen(unix.SizeofInet4Pktinfo))
copy(buf, unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), int(unsafe.Sizeof(hdr))))
info := unix.Inet4Pktinfo{
Ifindex: ifidx,
Spec_dst: addr.As4(),
}
copy(buf[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&info)), unix.SizeofInet4Pktinfo))
} else {
buf = make([]byte, unix.CmsgSpace(unix.SizeofInet6Pktinfo))
hdr := unix.Cmsghdr{
Level: unix.IPPROTO_IPV6,
Type: unix.IPV6_PKTINFO,
}
hdr.SetLen(unix.CmsgLen(unix.SizeofInet6Pktinfo))
copy(buf, unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), int(unsafe.Sizeof(hdr))))
info := unix.Inet6Pktinfo{
Ifindex: uint32(ifidx),
Addr: addr.As16(),
}
copy(buf[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&info)), unix.SizeofInet6Pktinfo))
}
ep.src = buf
}
func Test_setSrcControl(t *testing.T) {
t.Run("IPv4", func(t *testing.T) {
ep := &StdNetEndpoint{
AddrPort: netip.MustParseAddrPort("127.0.0.1:1234"),
}
setSrc(ep, netip.MustParseAddr("127.0.0.1"), 5)
control := make([]byte, stickyControlSize)
setSrcControl(&control, ep)
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
if hdr.Level != unix.IPPROTO_IP {
t.Errorf("unexpected level: %d", hdr.Level)
}
if hdr.Type != unix.IP_PKTINFO {
t.Errorf("unexpected type: %d", hdr.Type)
}
if uint(hdr.Len) != uint(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{})))) {
t.Errorf("unexpected length: %d", hdr.Len)
}
info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
if info.Spec_dst[0] != 127 || info.Spec_dst[1] != 0 || info.Spec_dst[2] != 0 || info.Spec_dst[3] != 1 {
t.Errorf("unexpected address: %v", info.Spec_dst)
}
if info.Ifindex != 5 {
t.Errorf("unexpected ifindex: %d", info.Ifindex)
}
})
t.Run("IPv6", func(t *testing.T) {
ep := &StdNetEndpoint{
AddrPort: netip.MustParseAddrPort("[::1]:1234"),
}
setSrc(ep, netip.MustParseAddr("::1"), 5)
control := make([]byte, stickyControlSize)
setSrcControl(&control, ep)
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
if hdr.Level != unix.IPPROTO_IPV6 {
t.Errorf("unexpected level: %d", hdr.Level)
}
if hdr.Type != unix.IPV6_PKTINFO {
t.Errorf("unexpected type: %d", hdr.Type)
}
if uint(hdr.Len) != uint(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet6Pktinfo{})))) {
t.Errorf("unexpected length: %d", hdr.Len)
}
info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
if info.Addr != ep.SrcIP().As16() {
t.Errorf("unexpected address: %v", info.Addr)
}
if info.Ifindex != 5 {
t.Errorf("unexpected ifindex: %d", info.Ifindex)
}
})
t.Run("ClearOnNoSrc", func(t *testing.T) {
control := make([]byte, stickyControlSize)
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
hdr.Level = 1
hdr.Type = 2
hdr.Len = 3
setSrcControl(&control, &StdNetEndpoint{})
if len(control) != 0 {
t.Errorf("unexpected control: %v", control)
}
})
}
func Test_getSrcFromControl(t *testing.T) {
t.Run("IPv4", func(t *testing.T) {
control := make([]byte, stickyControlSize)
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
hdr.Level = unix.IPPROTO_IP
hdr.Type = unix.IP_PKTINFO
hdr.SetLen(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{}))))
info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
info.Spec_dst = [4]byte{127, 0, 0, 1}
info.Ifindex = 5
ep := &StdNetEndpoint{}
getSrcFromControl(control, ep)
if ep.SrcIP() != netip.MustParseAddr("127.0.0.1") {
t.Errorf("unexpected address: %v", ep.SrcIP())
}
if ep.SrcIfidx() != 5 {
t.Errorf("unexpected ifindex: %d", ep.SrcIfidx())
}
})
t.Run("IPv6", func(t *testing.T) {
control := make([]byte, stickyControlSize)
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
hdr.Level = unix.IPPROTO_IPV6
hdr.Type = unix.IPV6_PKTINFO
hdr.SetLen(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet6Pktinfo{}))))
info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
info.Addr = [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}
info.Ifindex = 5
ep := &StdNetEndpoint{}
getSrcFromControl(control, ep)
if ep.SrcIP() != netip.MustParseAddr("::1") {
t.Errorf("unexpected address: %v", ep.SrcIP())
}
if ep.SrcIfidx() != 5 {
t.Errorf("unexpected ifindex: %d", ep.SrcIfidx())
}
})
t.Run("ClearOnEmpty", func(t *testing.T) {
var control []byte
ep := &StdNetEndpoint{}
setSrc(ep, netip.MustParseAddr("::1"), 5)
getSrcFromControl(control, ep)
if ep.SrcIP().IsValid() {
t.Errorf("unexpected address: %v", ep.SrcIP())
}
if ep.SrcIfidx() != 0 {
t.Errorf("unexpected ifindex: %d", ep.SrcIfidx())
}
})
t.Run("Multiple", func(t *testing.T) {
zeroControl := make([]byte, unix.CmsgSpace(0))
zeroHdr := (*unix.Cmsghdr)(unsafe.Pointer(&zeroControl[0]))
zeroHdr.SetLen(unix.CmsgLen(0))
control := make([]byte, unix.CmsgSpace(unix.SizeofInet4Pktinfo))
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
hdr.Level = unix.IPPROTO_IP
hdr.Type = unix.IP_PKTINFO
hdr.SetLen(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{}))))
info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
info.Spec_dst = [4]byte{127, 0, 0, 1}
info.Ifindex = 5
combined := make([]byte, 0)
combined = append(combined, zeroControl...)
combined = append(combined, control...)
ep := &StdNetEndpoint{}
getSrcFromControl(combined, ep)
if ep.SrcIP() != netip.MustParseAddr("127.0.0.1") {
t.Errorf("unexpected address: %v", ep.SrcIP())
}
if ep.SrcIfidx() != 5 {
t.Errorf("unexpected ifindex: %d", ep.SrcIfidx())
}
})
}
func Test_listenConfig(t *testing.T) {
t.Run("IPv4", func(t *testing.T) {
conn, err := listenConfig().ListenPacket(context.Background(), "udp4", ":0")
if err != nil {
t.Fatal(err)
}
defer conn.Close()
sc, err := conn.(*net.UDPConn).SyscallConn()
if err != nil {
t.Fatal(err)
}
if runtime.GOOS == "linux" {
var i int
sc.Control(func(fd uintptr) {
i, err = unix.GetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_PKTINFO)
})
if err != nil {
t.Fatal(err)
}
if i != 1 {
t.Error("IP_PKTINFO not set!")
}
} else {
t.Logf("listenConfig() does not set IPV6_RECVPKTINFO on %s", runtime.GOOS)
}
})
t.Run("IPv6", func(t *testing.T) {
conn, err := listenConfig().ListenPacket(context.Background(), "udp6", ":0")
if err != nil {
t.Fatal(err)
}
sc, err := conn.(*net.UDPConn).SyscallConn()
if err != nil {
t.Fatal(err)
}
if runtime.GOOS == "linux" {
var i int
sc.Control(func(fd uintptr) {
i, err = unix.GetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVPKTINFO)
})
if err != nil {
t.Fatal(err)
}
if i != 1 {
t.Error("IPV6_PKTINFO not set!")
}
} else {
t.Logf("listenConfig() does not set IPV6_RECVPKTINFO on %s", runtime.GOOS)
}
})
}

View file

@ -1,254 +0,0 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package winrio
import (
"log"
"sync"
"syscall"
"unsafe"
"golang.org/x/sys/windows"
)
const (
MsgDontNotify = 1
MsgDefer = 2
MsgWaitAll = 4
MsgCommitOnly = 8
MaxCqSize = 0x8000000
invalidBufferId = 0xFFFFFFFF
invalidCq = 0
invalidRq = 0
corruptCq = 0xFFFFFFFF
)
var extensionFunctionTable struct {
cbSize uint32
rioReceive uintptr
rioReceiveEx uintptr
rioSend uintptr
rioSendEx uintptr
rioCloseCompletionQueue uintptr
rioCreateCompletionQueue uintptr
rioCreateRequestQueue uintptr
rioDequeueCompletion uintptr
rioDeregisterBuffer uintptr
rioNotify uintptr
rioRegisterBuffer uintptr
rioResizeCompletionQueue uintptr
rioResizeRequestQueue uintptr
}
type Cq uintptr
type Rq uintptr
type BufferId uintptr
type Buffer struct {
Id BufferId
Offset uint32
Length uint32
}
type Result struct {
Status int32
BytesTransferred uint32
SocketContext uint64
RequestContext uint64
}
type notificationCompletionType uint32
const (
eventCompletion notificationCompletionType = 1
iocpCompletion notificationCompletionType = 2
)
type eventNotificationCompletion struct {
completionType notificationCompletionType
event windows.Handle
notifyReset uint32
}
type iocpNotificationCompletion struct {
completionType notificationCompletionType
iocp windows.Handle
key uintptr
overlapped *windows.Overlapped
}
var (
initialized sync.Once
available bool
)
func Initialize() bool {
initialized.Do(func() {
var (
err error
socket windows.Handle
cq Cq
)
defer func() {
if err == nil {
return
}
if maj, _, _ := windows.RtlGetNtVersionNumbers(); maj <= 7 {
return
}
log.Printf("Registered I/O is unavailable: %v", err)
}()
socket, err = Socket(windows.AF_INET, windows.SOCK_DGRAM, windows.IPPROTO_UDP)
if err != nil {
return
}
defer windows.CloseHandle(socket)
WSAID_MULTIPLE_RIO := &windows.GUID{0x8509e081, 0x96dd, 0x4005, [8]byte{0xb1, 0x65, 0x9e, 0x2e, 0xe8, 0xc7, 0x9e, 0x3f}}
const SIO_GET_MULTIPLE_EXTENSION_FUNCTION_POINTER = 0xc8000024
ob := uint32(0)
err = windows.WSAIoctl(socket, SIO_GET_MULTIPLE_EXTENSION_FUNCTION_POINTER,
(*byte)(unsafe.Pointer(WSAID_MULTIPLE_RIO)), uint32(unsafe.Sizeof(*WSAID_MULTIPLE_RIO)),
(*byte)(unsafe.Pointer(&extensionFunctionTable)), uint32(unsafe.Sizeof(extensionFunctionTable)),
&ob, nil, 0)
if err != nil {
return
}
// While we should be able to stop here, after getting the function pointers, some anti-virus software actually causes
// failures in RIOCreateRequestQueue, so keep going to be certain this is supported.
var iocp windows.Handle
iocp, err = windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0)
if err != nil {
return
}
defer windows.CloseHandle(iocp)
var overlapped windows.Overlapped
cq, err = CreateIOCPCompletionQueue(2, iocp, 0, &overlapped)
if err != nil {
return
}
defer CloseCompletionQueue(cq)
_, err = CreateRequestQueue(socket, 1, 1, 1, 1, cq, cq, 0)
if err != nil {
return
}
available = true
})
return available
}
func Socket(af, typ, proto int32) (windows.Handle, error) {
return windows.WSASocket(af, typ, proto, nil, 0, windows.WSA_FLAG_REGISTERED_IO)
}
func CloseCompletionQueue(cq Cq) {
_, _, _ = syscall.Syscall(extensionFunctionTable.rioCloseCompletionQueue, 1, uintptr(cq), 0, 0)
}
func CreateEventCompletionQueue(queueSize uint32, event windows.Handle, notifyReset bool) (Cq, error) {
notificationCompletion := &eventNotificationCompletion{
completionType: eventCompletion,
event: event,
}
if notifyReset {
notificationCompletion.notifyReset = 1
}
ret, _, err := syscall.Syscall(extensionFunctionTable.rioCreateCompletionQueue, 2, uintptr(queueSize), uintptr(unsafe.Pointer(notificationCompletion)), 0)
if ret == invalidCq {
return 0, err
}
return Cq(ret), nil
}
func CreateIOCPCompletionQueue(queueSize uint32, iocp windows.Handle, key uintptr, overlapped *windows.Overlapped) (Cq, error) {
notificationCompletion := &iocpNotificationCompletion{
completionType: iocpCompletion,
iocp: iocp,
key: key,
overlapped: overlapped,
}
ret, _, err := syscall.Syscall(extensionFunctionTable.rioCreateCompletionQueue, 2, uintptr(queueSize), uintptr(unsafe.Pointer(notificationCompletion)), 0)
if ret == invalidCq {
return 0, err
}
return Cq(ret), nil
}
func CreatePolledCompletionQueue(queueSize uint32) (Cq, error) {
ret, _, err := syscall.Syscall(extensionFunctionTable.rioCreateCompletionQueue, 2, uintptr(queueSize), 0, 0)
if ret == invalidCq {
return 0, err
}
return Cq(ret), nil
}
func CreateRequestQueue(socket windows.Handle, maxOutstandingReceive, maxReceiveDataBuffers, maxOutstandingSend, maxSendDataBuffers uint32, receiveCq, sendCq Cq, socketContext uintptr) (Rq, error) {
ret, _, err := syscall.Syscall9(extensionFunctionTable.rioCreateRequestQueue, 8, uintptr(socket), uintptr(maxOutstandingReceive), uintptr(maxReceiveDataBuffers), uintptr(maxOutstandingSend), uintptr(maxSendDataBuffers), uintptr(receiveCq), uintptr(sendCq), socketContext, 0)
if ret == invalidRq {
return 0, err
}
return Rq(ret), nil
}
func DequeueCompletion(cq Cq, results []Result) uint32 {
var array uintptr
if len(results) > 0 {
array = uintptr(unsafe.Pointer(&results[0]))
}
ret, _, _ := syscall.Syscall(extensionFunctionTable.rioDequeueCompletion, 3, uintptr(cq), array, uintptr(len(results)))
if ret == corruptCq {
panic("cq is corrupt")
}
return uint32(ret)
}
func DeregisterBuffer(id BufferId) {
_, _, _ = syscall.Syscall(extensionFunctionTable.rioDeregisterBuffer, 1, uintptr(id), 0, 0)
}
func RegisterBuffer(buffer []byte) (BufferId, error) {
var buf unsafe.Pointer
if len(buffer) > 0 {
buf = unsafe.Pointer(&buffer[0])
}
return RegisterPointer(buf, uint32(len(buffer)))
}
func RegisterPointer(ptr unsafe.Pointer, size uint32) (BufferId, error) {
ret, _, err := syscall.Syscall(extensionFunctionTable.rioRegisterBuffer, 2, uintptr(ptr), uintptr(size), 0)
if ret == invalidBufferId {
return 0, err
}
return BufferId(ret), nil
}
func SendEx(rq Rq, buf *Buffer, dataBufferCount uint32, localAddress, remoteAddress, controlContext, flags *Buffer, sflags uint32, requestContext uintptr) error {
ret, _, err := syscall.Syscall9(extensionFunctionTable.rioSendEx, 9, uintptr(rq), uintptr(unsafe.Pointer(buf)), uintptr(dataBufferCount), uintptr(unsafe.Pointer(localAddress)), uintptr(unsafe.Pointer(remoteAddress)), uintptr(unsafe.Pointer(controlContext)), uintptr(unsafe.Pointer(flags)), uintptr(sflags), requestContext)
if ret == 0 {
return err
}
return nil
}
func ReceiveEx(rq Rq, buf *Buffer, dataBufferCount uint32, localAddress, remoteAddress, controlContext, flags *Buffer, sflags uint32, requestContext uintptr) error {
ret, _, err := syscall.Syscall9(extensionFunctionTable.rioReceiveEx, 9, uintptr(rq), uintptr(unsafe.Pointer(buf)), uintptr(dataBufferCount), uintptr(unsafe.Pointer(localAddress)), uintptr(unsafe.Pointer(remoteAddress)), uintptr(unsafe.Pointer(controlContext)), uintptr(unsafe.Pointer(flags)), uintptr(sflags), requestContext)
if ret == 0 {
return err
}
return nil
}
func Notify(cq Cq) error {
ret, _, _ := syscall.Syscall(extensionFunctionTable.rioNotify, 1, uintptr(cq), 0, 0)
if ret != 0 {
return windows.Errno(ret)
}
return nil
}
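Callers are expected to gate the registered-I/O path on this probe. A minimal hypothetical sketch (helper name assumed, reusing the Socket wrapper above):

// Illustrative sketch: open a RIO-capable UDP socket only when Initialize()
// succeeds; otherwise signal that classic Winsock sockets should be used.
func newRIOSocket() (windows.Handle, bool, error) {
	if !Initialize() {
		return windows.InvalidHandle, false, nil
	}
	s, err := Socket(windows.AF_INET, windows.SOCK_DGRAM, windows.IPPROTO_UDP)
	if err != nil {
		return windows.InvalidHandle, false, err
	}
	return s, true, nil
}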

View file

@ -1,201 +1,173 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/ */
package device package device
import ( import (
"container/list"
"encoding/binary"
"errors" "errors"
"math/bits" "math/bits"
"net" "net"
"net/netip"
"sync" "sync"
"unsafe" "unsafe"
) )
type parentIndirection struct {
parentBit **trieEntry
parentBitType uint8
}
type trieEntry struct { type trieEntry struct {
peer *Peer cidr uint
child [2]*trieEntry child [2]*trieEntry
parent parentIndirection bits net.IP
cidr uint8 peer *Peer
bitAtByte uint8
bitAtShift uint8 // index of "branching" bit
bits []byte
perPeerElem *list.Element bit_at_byte uint
bit_at_shift uint
} }
func commonBits(ip1, ip2 []byte) uint8 { func isLittleEndian() bool {
one := uint32(1)
return *(*byte)(unsafe.Pointer(&one)) != 0
}
func swapU32(i uint32) uint32 {
if !isLittleEndian() {
return i
}
return bits.ReverseBytes32(i)
}
func swapU64(i uint64) uint64 {
if !isLittleEndian() {
return i
}
return bits.ReverseBytes64(i)
}
func commonBits(ip1 net.IP, ip2 net.IP) uint {
size := len(ip1) size := len(ip1)
if size == net.IPv4len { if size == net.IPv4len {
a := binary.BigEndian.Uint32(ip1) a := (*uint32)(unsafe.Pointer(&ip1[0]))
b := binary.BigEndian.Uint32(ip2) b := (*uint32)(unsafe.Pointer(&ip2[0]))
x := a ^ b x := *a ^ *b
return uint8(bits.LeadingZeros32(x)) return uint(bits.LeadingZeros32(swapU32(x)))
} else if size == net.IPv6len { } else if size == net.IPv6len {
a := binary.BigEndian.Uint64(ip1) a := (*uint64)(unsafe.Pointer(&ip1[0]))
b := binary.BigEndian.Uint64(ip2) b := (*uint64)(unsafe.Pointer(&ip2[0]))
x := a ^ b x := *a ^ *b
if x != 0 { if x != 0 {
return uint8(bits.LeadingZeros64(x)) return uint(bits.LeadingZeros64(swapU64(x)))
} }
a = binary.BigEndian.Uint64(ip1[8:]) a = (*uint64)(unsafe.Pointer(&ip1[8]))
b = binary.BigEndian.Uint64(ip2[8:]) b = (*uint64)(unsafe.Pointer(&ip2[8]))
x = a ^ b x = *a ^ *b
return 64 + uint8(bits.LeadingZeros64(x)) return 64 + uint(bits.LeadingZeros64(swapU64(x)))
} else { } else {
panic("Wrong size bit string") panic("Wrong size bit string")
} }
} }
func (node *trieEntry) addToPeerEntries() { func (node *trieEntry) removeByPeer(p *Peer) *trieEntry {
node.perPeerElem = node.peer.trieEntries.PushBack(node) if node == nil {
} return node
func (node *trieEntry) removeFromPeerEntries() {
if node.perPeerElem != nil {
node.peer.trieEntries.Remove(node.perPeerElem)
node.perPeerElem = nil
} }
}
func (node *trieEntry) choose(ip []byte) byte { // walk recursively
return (ip[node.bitAtByte] >> node.bitAtShift) & 1
}
func (node *trieEntry) maskSelf() { node.child[0] = node.child[0].removeByPeer(p)
mask := net.CIDRMask(int(node.cidr), len(node.bits)*8) node.child[1] = node.child[1].removeByPeer(p)
for i := 0; i < len(mask); i++ {
node.bits[i] &= mask[i] if node.peer != p {
return node
} }
}
func (node *trieEntry) zeroizePointers() { // remove peer & merge
// Make the garbage collector's life slightly easier
node.peer = nil node.peer = nil
node.child[0] = nil if node.child[0] == nil {
node.child[1] = nil return node.child[1]
node.parent.parentBit = nil }
return node.child[0]
} }
func (node *trieEntry) nodePlacement(ip []byte, cidr uint8) (parent *trieEntry, exact bool) { func (node *trieEntry) choose(ip net.IP) byte {
for node != nil && node.cidr <= cidr && commonBits(node.bits, ip) >= node.cidr { return (ip[node.bit_at_byte] >> node.bit_at_shift) & 1
parent = node }
if parent.cidr == cidr {
exact = true func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry {
return
// at leaf
if node == nil {
return &trieEntry{
bits: ip,
peer: peer,
cidr: cidr,
bit_at_byte: cidr / 8,
bit_at_shift: 7 - (cidr % 8),
}
}
// traverse deeper
common := commonBits(node.bits, ip)
if node.cidr <= cidr && common >= node.cidr {
if node.cidr == cidr {
node.peer = peer
return node
} }
bit := node.choose(ip) bit := node.choose(ip)
node = node.child[bit] node.child[bit] = node.child[bit].insert(ip, cidr, peer)
return node
} }
return
}
func (trie parentIndirection) insert(ip []byte, cidr uint8, peer *Peer) { // split node
if *trie.parentBit == nil {
node := &trieEntry{
peer: peer,
parent: trie,
bits: ip,
cidr: cidr,
bitAtByte: cidr / 8,
bitAtShift: 7 - (cidr % 8),
}
node.maskSelf()
node.addToPeerEntries()
*trie.parentBit = node
return
}
node, exact := (*trie.parentBit).nodePlacement(ip, cidr)
if exact {
node.removeFromPeerEntries()
node.peer = peer
node.addToPeerEntries()
return
}
newNode := &trieEntry{ newNode := &trieEntry{
peer: peer, bits: ip,
bits: ip, peer: peer,
cidr: cidr, cidr: cidr,
bitAtByte: cidr / 8, bit_at_byte: cidr / 8,
bitAtShift: 7 - (cidr % 8), bit_at_shift: 7 - (cidr % 8),
} }
newNode.maskSelf()
newNode.addToPeerEntries()
var down *trieEntry cidr = min(cidr, common)
if node == nil {
down = *trie.parentBit // check for shorter prefix
} else {
bit := node.choose(ip)
down = node.child[bit]
if down == nil {
newNode.parent = parentIndirection{&node.child[bit], bit}
node.child[bit] = newNode
return
}
}
common := commonBits(down.bits, ip)
if common < cidr {
cidr = common
}
parent := node
if newNode.cidr == cidr { if newNode.cidr == cidr {
bit := newNode.choose(down.bits) bit := newNode.choose(node.bits)
down.parent = parentIndirection{&newNode.child[bit], bit} newNode.child[bit] = node
newNode.child[bit] = down return newNode
if parent == nil {
newNode.parent = trie
*trie.parentBit = newNode
} else {
bit := parent.choose(newNode.bits)
newNode.parent = parentIndirection{&parent.child[bit], bit}
parent.child[bit] = newNode
}
return
} }
node = &trieEntry{ // create new parent for node & newNode
bits: append([]byte{}, newNode.bits...),
cidr: cidr,
bitAtByte: cidr / 8,
bitAtShift: 7 - (cidr % 8),
}
node.maskSelf()
bit := node.choose(down.bits) parent := &trieEntry{
down.parent = parentIndirection{&node.child[bit], bit} bits: ip,
node.child[bit] = down peer: nil,
bit = node.choose(newNode.bits) cidr: cidr,
newNode.parent = parentIndirection{&node.child[bit], bit} bit_at_byte: cidr / 8,
node.child[bit] = newNode bit_at_shift: 7 - (cidr % 8),
if parent == nil {
node.parent = trie
*trie.parentBit = node
} else {
bit := parent.choose(node.bits)
node.parent = parentIndirection{&parent.child[bit], bit}
parent.child[bit] = node
} }
bit := parent.choose(ip)
parent.child[bit] = newNode
parent.child[bit^1] = node
return parent
} }
func (node *trieEntry) lookup(ip []byte) *Peer { func (node *trieEntry) lookup(ip net.IP) *Peer {
var found *Peer var found *Peer
size := uint8(len(ip)) size := uint(len(ip))
for node != nil && commonBits(node.bits, ip) >= node.cidr { for node != nil && commonBits(node.bits, ip) >= node.cidr {
if node.peer != nil { if node.peer != nil {
found = node.peer found = node.peer
} }
if node.bitAtByte == size { if node.bit_at_byte == size {
break break
} }
bit := node.choose(ip) bit := node.choose(ip)
@ -204,91 +176,76 @@ func (node *trieEntry) lookup(ip []byte) *Peer {
return found return found
} }
func (node *trieEntry) entriesForPeer(p *Peer, results []net.IPNet) []net.IPNet {
if node == nil {
return results
}
if node.peer == p {
mask := net.CIDRMask(int(node.cidr), len(node.bits)*8)
results = append(results, net.IPNet{
Mask: mask,
IP: node.bits.Mask(mask),
})
}
results = node.child[0].entriesForPeer(p, results)
results = node.child[1].entriesForPeer(p, results)
return results
}
type AllowedIPs struct { type AllowedIPs struct {
IPv4 *trieEntry IPv4 *trieEntry
IPv6 *trieEntry IPv6 *trieEntry
mutex sync.RWMutex mutex sync.RWMutex
} }
func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(prefix netip.Prefix) bool) { func (table *AllowedIPs) EntriesForPeer(peer *Peer) []net.IPNet {
table.mutex.RLock() table.mutex.RLock()
defer table.mutex.RUnlock() defer table.mutex.RUnlock()
for elem := peer.trieEntries.Front(); elem != nil; elem = elem.Next() { allowed := make([]net.IPNet, 0, 10)
node := elem.Value.(*trieEntry) allowed = table.IPv4.entriesForPeer(peer, allowed)
a, _ := netip.AddrFromSlice(node.bits) allowed = table.IPv6.entriesForPeer(peer, allowed)
if !cb(netip.PrefixFrom(a, int(node.cidr))) { return allowed
return }
}
} func (table *AllowedIPs) Reset() {
table.mutex.Lock()
defer table.mutex.Unlock()
table.IPv4 = nil
table.IPv6 = nil
} }
func (table *AllowedIPs) RemoveByPeer(peer *Peer) { func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
table.mutex.Lock() table.mutex.Lock()
defer table.mutex.Unlock() defer table.mutex.Unlock()
var next *list.Element table.IPv4 = table.IPv4.removeByPeer(peer)
for elem := peer.trieEntries.Front(); elem != nil; elem = next { table.IPv6 = table.IPv6.removeByPeer(peer)
next = elem.Next()
node := elem.Value.(*trieEntry)
node.removeFromPeerEntries()
node.peer = nil
if node.child[0] != nil && node.child[1] != nil {
continue
}
bit := 0
if node.child[0] == nil {
bit = 1
}
child := node.child[bit]
if child != nil {
child.parent = node.parent
}
*node.parent.parentBit = child
if node.child[0] != nil || node.child[1] != nil || node.parent.parentBitType > 1 {
node.zeroizePointers()
continue
}
parent := (*trieEntry)(unsafe.Pointer(uintptr(unsafe.Pointer(node.parent.parentBit)) - unsafe.Offsetof(node.child) - unsafe.Sizeof(node.child[0])*uintptr(node.parent.parentBitType)))
if parent.peer != nil {
node.zeroizePointers()
continue
}
child = parent.child[node.parent.parentBitType^1]
if child != nil {
child.parent = parent.parent
}
*parent.parent.parentBit = child
node.zeroizePointers()
parent.zeroizePointers()
}
} }
func (table *AllowedIPs) Insert(prefix netip.Prefix, peer *Peer) { func (table *AllowedIPs) Insert(ip net.IP, cidr uint, peer *Peer) {
table.mutex.Lock() table.mutex.Lock()
defer table.mutex.Unlock() defer table.mutex.Unlock()
if prefix.Addr().Is6() { switch len(ip) {
ip := prefix.Addr().As16() case net.IPv6len:
parentIndirection{&table.IPv6, 2}.insert(ip[:], uint8(prefix.Bits()), peer) table.IPv6 = table.IPv6.insert(ip, cidr, peer)
} else if prefix.Addr().Is4() { case net.IPv4len:
ip := prefix.Addr().As4() table.IPv4 = table.IPv4.insert(ip, cidr, peer)
parentIndirection{&table.IPv4, 2}.insert(ip[:], uint8(prefix.Bits()), peer) default:
} else {
panic(errors.New("inserting unknown address type")) panic(errors.New("inserting unknown address type"))
} }
} }
func (table *AllowedIPs) Lookup(ip []byte) *Peer { func (table *AllowedIPs) LookupIPv4(address []byte) *Peer {
table.mutex.RLock() table.mutex.RLock()
defer table.mutex.RUnlock() defer table.mutex.RUnlock()
switch len(ip) { return table.IPv4.lookup(address)
case net.IPv6len: }
return table.IPv6.lookup(ip)
case net.IPv4len: func (table *AllowedIPs) LookupIPv6(address []byte) *Peer {
return table.IPv4.lookup(ip) table.mutex.RLock()
default: defer table.mutex.RUnlock()
panic(errors.New("looking up unknown address type")) return table.IPv6.lookup(address)
}
} }
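To make the prefix arithmetic in commonBits concrete, here is a standalone worked sketch (not part of either tree) that mirrors the IPv4 case: XOR the two addresses and count the leading zero bits of the big-endian result.

package main

import (
	"encoding/binary"
	"fmt"
	"math/bits"
)

// commonBits4 mirrors the IPv4 branch of commonBits: the number of identical
// leading bits equals the number of leading zeros in the XOR of the addresses.
func commonBits4(a, b [4]byte) int {
	x := binary.BigEndian.Uint32(a[:]) ^ binary.BigEndian.Uint32(b[:])
	return bits.LeadingZeros32(x)
}

func main() {
	fmt.Println(commonBits4([4]byte{10, 0, 0, 1}, [4]byte{10, 0, 0, 2})) // 30
	fmt.Println(commonBits4([4]byte{10, 0, 0, 1}, [4]byte{10, 1, 0, 0})) // 15
}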

View file

@ -1,28 +1,25 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/ */
package device package device
import ( import (
"math/rand" "math/rand"
"net"
"net/netip"
"sort" "sort"
"testing" "testing"
) )
const ( const (
NumberOfPeers = 100 NumberOfPeers = 100
NumberOfPeerRemovals = 4 NumberOfAddresses = 250
NumberOfAddresses = 250 NumberOfTests = 10000
NumberOfTests = 10000
) )
type SlowNode struct { type SlowNode struct {
peer *Peer peer *Peer
cidr uint8 cidr uint
bits []byte bits []byte
} }
@ -40,7 +37,7 @@ func (r SlowRouter) Swap(i, j int) {
r[i], r[j] = r[j], r[i] r[i], r[j] = r[j], r[i]
} }
func (r SlowRouter) Insert(addr []byte, cidr uint8, peer *Peer) SlowRouter { func (r SlowRouter) Insert(addr []byte, cidr uint, peer *Peer) SlowRouter {
for _, t := range r { for _, t := range r {
if t.cidr == cidr && commonBits(t.bits, addr) >= cidr { if t.cidr == cidr && commonBits(t.bits, addr) >= cidr {
t.peer = peer t.peer = peer
@ -67,75 +64,68 @@ func (r SlowRouter) Lookup(addr []byte) *Peer {
return nil return nil
} }
func (r SlowRouter) RemoveByPeer(peer *Peer) SlowRouter { func TestTrieRandomIPv4(t *testing.T) {
n := 0 var trie *trieEntry
for _, x := range r { var slow SlowRouter
if x.peer != peer {
r[n] = x
n++
}
}
return r[:n]
}
func TestTrieRandom(t *testing.T) {
var slow4, slow6 SlowRouter
var peers []*Peer var peers []*Peer
var allowedIPs AllowedIPs
rand.Seed(1) rand.Seed(1)
for n := 0; n < NumberOfPeers; n++ { const AddressLength = 4
for n := 0; n < NumberOfPeers; n += 1 {
peers = append(peers, &Peer{}) peers = append(peers, &Peer{})
} }
for n := 0; n < NumberOfAddresses; n++ { for n := 0; n < NumberOfAddresses; n += 1 {
var addr4 [4]byte var addr [AddressLength]byte
rand.Read(addr4[:]) rand.Read(addr[:])
cidr := uint8(rand.Intn(32) + 1) cidr := uint(rand.Uint32() % (AddressLength * 8))
index := rand.Intn(NumberOfPeers) index := rand.Int() % NumberOfPeers
allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4(addr4), int(cidr)), peers[index]) trie = trie.insert(addr[:], cidr, peers[index])
slow4 = slow4.Insert(addr4[:], cidr, peers[index]) slow = slow.Insert(addr[:], cidr, peers[index])
var addr6 [16]byte
rand.Read(addr6[:])
cidr = uint8(rand.Intn(128) + 1)
index = rand.Intn(NumberOfPeers)
allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(addr6), int(cidr)), peers[index])
slow6 = slow6.Insert(addr6[:], cidr, peers[index])
} }
var p int for n := 0; n < NumberOfTests; n += 1 {
for p = 0; ; p++ { var addr [AddressLength]byte
for n := 0; n < NumberOfTests; n++ { rand.Read(addr[:])
var addr4 [4]byte peer1 := slow.Lookup(addr[:])
rand.Read(addr4[:]) peer2 := trie.lookup(addr[:])
peer1 := slow4.Lookup(addr4[:]) if peer1 != peer2 {
peer2 := allowedIPs.Lookup(addr4[:]) t.Error("Trie did not match naive implementation, for:", addr)
if peer1 != peer2 { }
t.Errorf("Trie did not match naive implementation, for %v: want %p, got %p", net.IP(addr4[:]), peer1, peer2) }
} }
var addr6 [16]byte func TestTrieRandomIPv6(t *testing.T) {
rand.Read(addr6[:]) var trie *trieEntry
peer1 = slow6.Lookup(addr6[:]) var slow SlowRouter
peer2 = allowedIPs.Lookup(addr6[:]) var peers []*Peer
if peer1 != peer2 {
t.Errorf("Trie did not match naive implementation, for %v: want %p, got %p", net.IP(addr6[:]), peer1, peer2) rand.Seed(1)
}
const AddressLength = 16
for n := 0; n < NumberOfPeers; n += 1 {
peers = append(peers, &Peer{})
}
for n := 0; n < NumberOfAddresses; n += 1 {
var addr [AddressLength]byte
rand.Read(addr[:])
cidr := uint(rand.Uint32() % (AddressLength * 8))
index := rand.Int() % NumberOfPeers
trie = trie.insert(addr[:], cidr, peers[index])
slow = slow.Insert(addr[:], cidr, peers[index])
}
for n := 0; n < NumberOfTests; n += 1 {
var addr [AddressLength]byte
rand.Read(addr[:])
peer1 := slow.Lookup(addr[:])
peer2 := trie.lookup(addr[:])
if peer1 != peer2 {
t.Error("Trie did not match naive implementation, for:", addr)
} }
if p >= len(peers) || p >= NumberOfPeerRemovals {
break
}
allowedIPs.RemoveByPeer(peers[p])
slow4 = slow4.RemoveByPeer(peers[p])
slow6 = slow6.RemoveByPeer(peers[p])
}
for ; p < len(peers); p++ {
allowedIPs.RemoveByPeer(peers[p])
}
if allowedIPs.IPv4 != nil || allowedIPs.IPv6 != nil {
t.Error("Failed to remove all nodes from trie by peer")
} }
} }

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/ */
package device package device
@ -8,17 +8,40 @@ package device
import ( import (
"math/rand" "math/rand"
"net" "net"
"net/netip"
"testing" "testing"
) )
/* Todo: More comprehensive
*/
type testPairCommonBits struct { type testPairCommonBits struct {
s1 []byte s1 []byte
s2 []byte s2 []byte
match uint8 match uint
}
type testPairTrieInsert struct {
key []byte
cidr uint
peer *Peer
}
type testPairTrieLookup struct {
key []byte
peer *Peer
}
func printTrie(t *testing.T, p *trieEntry) {
if p == nil {
return
}
t.Log(p)
printTrie(t, p.child[0])
printTrie(t, p.child[1])
} }
func TestCommonBits(t *testing.T) { func TestCommonBits(t *testing.T) {
tests := []testPairCommonBits{ tests := []testPairCommonBits{
{s1: []byte{1, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 7}, {s1: []byte{1, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 7},
{s1: []byte{0, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 13}, {s1: []byte{0, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 13},
@ -39,28 +62,27 @@ func TestCommonBits(t *testing.T) {
} }
} }
func benchmarkTrie(peerNumber, addressNumber, addressLength int, b *testing.B) { func benchmarkTrie(peerNumber int, addressNumber int, addressLength int, b *testing.B) {
var trie *trieEntry var trie *trieEntry
var peers []*Peer var peers []*Peer
root := parentIndirection{&trie, 2}
rand.Seed(1) rand.Seed(1)
const AddressLength = 4 const AddressLength = 4
for n := 0; n < peerNumber; n++ { for n := 0; n < peerNumber; n += 1 {
peers = append(peers, &Peer{}) peers = append(peers, &Peer{})
} }
for n := 0; n < addressNumber; n++ { for n := 0; n < addressNumber; n += 1 {
var addr [AddressLength]byte var addr [AddressLength]byte
rand.Read(addr[:]) rand.Read(addr[:])
cidr := uint8(rand.Uint32() % (AddressLength * 8)) cidr := uint(rand.Uint32() % (AddressLength * 8))
index := rand.Int() % peerNumber index := rand.Int() % peerNumber
root.insert(addr[:], cidr, peers[index]) trie = trie.insert(addr[:], cidr, peers[index])
} }
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n += 1 {
var addr [AddressLength]byte var addr [AddressLength]byte
rand.Read(addr[:]) rand.Read(addr[:])
trie.lookup(addr[:]) trie.lookup(addr[:])
@ -95,21 +117,21 @@ func TestTrieIPv4(t *testing.T) {
g := &Peer{} g := &Peer{}
h := &Peer{} h := &Peer{}
var allowedIPs AllowedIPs var trie *trieEntry
insert := func(peer *Peer, a, b, c, d byte, cidr uint8) { insert := func(peer *Peer, a, b, c, d byte, cidr uint) {
allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4([4]byte{a, b, c, d}), int(cidr)), peer) trie = trie.insert([]byte{a, b, c, d}, cidr, peer)
} }
assertEQ := func(peer *Peer, a, b, c, d byte) { assertEQ := func(peer *Peer, a, b, c, d byte) {
p := allowedIPs.Lookup([]byte{a, b, c, d}) p := trie.lookup([]byte{a, b, c, d})
if p != peer { if p != peer {
t.Error("Assert EQ failed") t.Error("Assert EQ failed")
} }
} }
assertNEQ := func(peer *Peer, a, b, c, d byte) { assertNEQ := func(peer *Peer, a, b, c, d byte) {
p := allowedIPs.Lookup([]byte{a, b, c, d}) p := trie.lookup([]byte{a, b, c, d})
if p == peer { if p == peer {
t.Error("Assert NEQ failed") t.Error("Assert NEQ failed")
} }
@ -151,7 +173,7 @@ func TestTrieIPv4(t *testing.T) {
assertEQ(a, 192, 0, 0, 0) assertEQ(a, 192, 0, 0, 0)
assertEQ(a, 255, 0, 0, 0) assertEQ(a, 255, 0, 0, 0)
allowedIPs.RemoveByPeer(a) trie = trie.removeByPeer(a)
assertNEQ(a, 1, 0, 0, 0) assertNEQ(a, 1, 0, 0, 0)
assertNEQ(a, 64, 0, 0, 0) assertNEQ(a, 64, 0, 0, 0)
@ -159,21 +181,12 @@ func TestTrieIPv4(t *testing.T) {
assertNEQ(a, 192, 0, 0, 0) assertNEQ(a, 192, 0, 0, 0)
assertNEQ(a, 255, 0, 0, 0) assertNEQ(a, 255, 0, 0, 0)
allowedIPs.RemoveByPeer(a) trie = nil
allowedIPs.RemoveByPeer(b)
allowedIPs.RemoveByPeer(c)
allowedIPs.RemoveByPeer(d)
allowedIPs.RemoveByPeer(e)
allowedIPs.RemoveByPeer(g)
allowedIPs.RemoveByPeer(h)
if allowedIPs.IPv4 != nil || allowedIPs.IPv6 != nil {
t.Error("Expected removing all the peers to empty trie, but it did not")
}
insert(a, 192, 168, 0, 0, 16) insert(a, 192, 168, 0, 0, 16)
insert(a, 192, 168, 0, 0, 24) insert(a, 192, 168, 0, 0, 24)
allowedIPs.RemoveByPeer(a) trie = trie.removeByPeer(a)
assertNEQ(a, 192, 168, 0, 1) assertNEQ(a, 192, 168, 0, 1)
} }
@ -191,7 +204,7 @@ func TestTrieIPv6(t *testing.T) {
g := &Peer{} g := &Peer{}
h := &Peer{} h := &Peer{}
var allowedIPs AllowedIPs var trie *trieEntry
expand := func(a uint32) []byte { expand := func(a uint32) []byte {
var out [4]byte var out [4]byte
@ -202,13 +215,13 @@ func TestTrieIPv6(t *testing.T) {
return out[:] return out[:]
} }
insert := func(peer *Peer, a, b, c, d uint32, cidr uint8) { insert := func(peer *Peer, a, b, c, d uint32, cidr uint) {
var addr []byte var addr []byte
addr = append(addr, expand(a)...) addr = append(addr, expand(a)...)
addr = append(addr, expand(b)...) addr = append(addr, expand(b)...)
addr = append(addr, expand(c)...) addr = append(addr, expand(c)...)
addr = append(addr, expand(d)...) addr = append(addr, expand(d)...)
allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(*(*[16]byte)(addr)), int(cidr)), peer) trie = trie.insert(addr, cidr, peer)
} }
assertEQ := func(peer *Peer, a, b, c, d uint32) { assertEQ := func(peer *Peer, a, b, c, d uint32) {
@ -217,7 +230,7 @@ func TestTrieIPv6(t *testing.T) {
addr = append(addr, expand(b)...) addr = append(addr, expand(b)...)
addr = append(addr, expand(c)...) addr = append(addr, expand(c)...)
addr = append(addr, expand(d)...) addr = append(addr, expand(d)...)
p := allowedIPs.Lookup(addr) p := trie.lookup(addr)
if p != peer { if p != peer {
t.Error("Assert EQ failed") t.Error("Assert EQ failed")
} }

View file

@ -1,24 +1,23 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/ */
package device package device
import ( import "errors"
"errors"
"github.com/amnezia-vpn/amneziawg-go/conn"
)
type DummyDatagram struct { type DummyDatagram struct {
msg []byte msg []byte
endpoint conn.Endpoint endpoint Endpoint
world bool // better type
} }
type DummyBind struct { type DummyBind struct {
in6 chan DummyDatagram in6 chan DummyDatagram
ou6 chan DummyDatagram
in4 chan DummyDatagram in4 chan DummyDatagram
ou4 chan DummyDatagram
closed bool closed bool
} }
@ -26,21 +25,21 @@ func (b *DummyBind) SetMark(v uint32) error {
return nil return nil
} }
func (b *DummyBind) ReceiveIPv6(buf []byte) (int, conn.Endpoint, error) { func (b *DummyBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
datagram, ok := <-b.in6 datagram, ok := <-b.in6
if !ok { if !ok {
return 0, nil, errors.New("closed") return 0, nil, errors.New("closed")
} }
copy(buf, datagram.msg) copy(buff, datagram.msg)
return len(datagram.msg), datagram.endpoint, nil return len(datagram.msg), datagram.endpoint, nil
} }
func (b *DummyBind) ReceiveIPv4(buf []byte) (int, conn.Endpoint, error) { func (b *DummyBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
datagram, ok := <-b.in4 datagram, ok := <-b.in4
if !ok { if !ok {
return 0, nil, errors.New("closed") return 0, nil, errors.New("closed")
} }
copy(buf, datagram.msg) copy(buff, datagram.msg)
return len(datagram.msg), datagram.endpoint, nil return len(datagram.msg), datagram.endpoint, nil
} }
@ -51,6 +50,6 @@ func (b *DummyBind) Close() error {
return nil return nil
} }
func (b *DummyBind) Send(buf []byte, end conn.Endpoint) error { func (b *DummyBind) Send(buff []byte, end Endpoint) error {
return nil return nil
} }

44
device/boundif_android.go Normal file
View file

@ -0,0 +1,44 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/
package device
import "errors"
func (device *Device) PeekLookAtSocketFd4() (fd int, err error) {
nb, ok := device.net.bind.(*nativeBind)
if !ok {
return 0, errors.New("no socket exists")
}
sysconn, err := nb.ipv4.SyscallConn()
if err != nil {
return
}
err = sysconn.Control(func(f uintptr) {
fd = int(f)
})
if err != nil {
return
}
return
}
func (device *Device) PeekLookAtSocketFd6() (fd int, err error) {
nb, ok := device.net.bind.(*nativeBind)
if !ok {
return 0, errors.New("no socket exists")
}
sysconn, err := nb.ipv6.SyscallConn()
if err != nil {
return
}
err = sysconn.Control(func(f uintptr) {
fd = int(f)
})
if err != nil {
return
}
return
}

62
device/boundif_windows.go Normal file
View file

@ -0,0 +1,62 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"encoding/binary"
"errors"
"unsafe"
"golang.org/x/sys/windows"
)
const (
sockoptIP_UNICAST_IF = 31
sockoptIPV6_UNICAST_IF = 31
)
func (device *Device) BindSocketToInterface4(interfaceIndex uint32) error {
/* MSDN says for IPv4 this needs to be in net byte order, so that it's like an IP address with leading zeros. */
bytes := make([]byte, 4)
binary.BigEndian.PutUint32(bytes, interfaceIndex)
interfaceIndex = *(*uint32)(unsafe.Pointer(&bytes[0]))
if device.net.bind == nil {
return errors.New("Bind is not yet initialized")
}
sysconn, err := device.net.bind.(*nativeBind).ipv4.SyscallConn()
if err != nil {
return err
}
err2 := sysconn.Control(func(fd uintptr) {
err = windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IP, sockoptIP_UNICAST_IF, int(interfaceIndex))
})
if err2 != nil {
return err2
}
if err != nil {
return err
}
return nil
}
func (device *Device) BindSocketToInterface6(interfaceIndex uint32) error {
sysconn, err := device.net.bind.(*nativeBind).ipv6.SyscallConn()
if err != nil {
return err
}
err2 := sysconn.Control(func(fd uintptr) {
err = windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IPV6, sockoptIPV6_UNICAST_IF, int(interfaceIndex))
})
if err2 != nil {
return err2
}
if err != nil {
return err
}
return nil
}
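The byte-order dance in BindSocketToInterface4 is easy to misread, so here is the same conversion as a standalone sketch with a worked value (helper name assumed; encoding/binary and unsafe imports taken as given): interface index 7 is written as the big-endian bytes 00 00 00 07, which a little-endian host then reinterprets as 0x07000000 (117440512), the form Windows expects for IP_UNICAST_IF.

// Illustrative sketch of the conversion used above.
func ipv4UnicastIfValue(ifindex uint32) int {
	b := make([]byte, 4)
	binary.BigEndian.PutUint32(b, ifindex)        // 7 -> 00 00 00 07
	return int(*(*uint32)(unsafe.Pointer(&b[0]))) // 0x07000000 on little-endian hosts
}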

View file

@ -1,137 +0,0 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"runtime"
"sync"
)
// An outboundQueue is a channel of QueueOutboundElements awaiting encryption.
// An outboundQueue is ref-counted using its wg field.
// An outboundQueue created with newOutboundQueue has one reference.
// Every additional writer must call wg.Add(1).
// Every completed writer must call wg.Done().
// When no further writers will be added,
// call wg.Done to remove the initial reference.
// When the refcount hits 0, the queue's channel is closed.
type outboundQueue struct {
c chan *QueueOutboundElementsContainer
wg sync.WaitGroup
}
func newOutboundQueue() *outboundQueue {
q := &outboundQueue{
c: make(chan *QueueOutboundElementsContainer, QueueOutboundSize),
}
q.wg.Add(1)
go func() {
q.wg.Wait()
close(q.c)
}()
return q
}
// A inboundQueue is similar to an outboundQueue; see those docs.
type inboundQueue struct {
c chan *QueueInboundElementsContainer
wg sync.WaitGroup
}
func newInboundQueue() *inboundQueue {
q := &inboundQueue{
c: make(chan *QueueInboundElementsContainer, QueueInboundSize),
}
q.wg.Add(1)
go func() {
q.wg.Wait()
close(q.c)
}()
return q
}
// A handshakeQueue is similar to an outboundQueue; see those docs.
type handshakeQueue struct {
c chan QueueHandshakeElement
wg sync.WaitGroup
}
func newHandshakeQueue() *handshakeQueue {
q := &handshakeQueue{
c: make(chan QueueHandshakeElement, QueueHandshakeSize),
}
q.wg.Add(1)
go func() {
q.wg.Wait()
close(q.c)
}()
return q
}
type autodrainingInboundQueue struct {
c chan *QueueInboundElementsContainer
}
// newAutodrainingInboundQueue returns a channel that will be drained when it gets GC'd.
// It is useful in cases in which it is hard to manage the lifetime of the channel.
// The returned channel must not be closed. Senders should signal shutdown using
// some other means, such as sending a sentinel nil value.
func newAutodrainingInboundQueue(device *Device) *autodrainingInboundQueue {
q := &autodrainingInboundQueue{
c: make(chan *QueueInboundElementsContainer, QueueInboundSize),
}
runtime.SetFinalizer(q, device.flushInboundQueue)
return q
}
func (device *Device) flushInboundQueue(q *autodrainingInboundQueue) {
for {
select {
case elemsContainer := <-q.c:
elemsContainer.Lock()
for _, elem := range elemsContainer.elems {
device.PutMessageBuffer(elem.buffer)
device.PutInboundElement(elem)
}
device.PutInboundElementsContainer(elemsContainer)
default:
return
}
}
}
type autodrainingOutboundQueue struct {
c chan *QueueOutboundElementsContainer
}
// newAutodrainingOutboundQueue returns a channel that will be drained when it gets GC'd.
// It is useful in cases in which it is hard to manage the lifetime of the channel.
// The returned channel must not be closed. Senders should signal shutdown using
// some other means, such as sending a sentinel nil value.
// All sends to the channel must be best-effort, because there may be no receivers.
func newAutodrainingOutboundQueue(device *Device) *autodrainingOutboundQueue {
q := &autodrainingOutboundQueue{
c: make(chan *QueueOutboundElementsContainer, QueueOutboundSize),
}
runtime.SetFinalizer(q, device.flushOutboundQueue)
return q
}
func (device *Device) flushOutboundQueue(q *autodrainingOutboundQueue) {
for {
select {
case elemsContainer := <-q.c:
elemsContainer.Lock()
for _, elem := range elemsContainer.elems {
device.PutMessageBuffer(elem.buffer)
device.PutOutboundElement(elem)
}
device.PutOutboundElementsContainer(elemsContainer)
default:
return
}
}
}
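The refcount discipline described in the comments above is easiest to see in a tiny hypothetical lifecycle sketch:

// Illustrative sketch: one extra writer registers itself with Add(1), then the
// initial reference is dropped; q.c closes automatically once the writer is done.
func exampleQueueLifecycle() {
	q := newOutboundQueue()
	q.wg.Add(1) // register an additional writer
	go func() {
		defer q.wg.Done()
		// ... send *QueueOutboundElementsContainer values on q.c here ...
	}()
	q.wg.Done() // drop the initial reference taken by newOutboundQueue
}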

187
device/conn.go Normal file
View file

@ -0,0 +1,187 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"errors"
"net"
"strings"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
)
const (
ConnRoutineNumber = 2
)
/* A Bind handles listening on a port for both IPv6 and IPv4 UDP traffic
*/
type Bind interface {
SetMark(value uint32) error
ReceiveIPv6(buff []byte) (int, Endpoint, error)
ReceiveIPv4(buff []byte) (int, Endpoint, error)
Send(buff []byte, end Endpoint) error
Close() error
}
/* An Endpoint maintains the source/destination caching for a peer
*
* dst : the remote address of a peer ("endpoint" in uapi terminology)
* src : the local address from which datagrams originate going to the peer
*/
type Endpoint interface {
ClearSrc() // clears the source address
SrcToString() string // returns the local source address (ip:port)
DstToString() string // returns the destination address (ip:port)
DstToBytes() []byte // used for mac2 cookie calculations
DstIP() net.IP
SrcIP() net.IP
}
func parseEndpoint(s string) (*net.UDPAddr, error) {
// ensure that the host is an IP address
host, _, err := net.SplitHostPort(s)
if err != nil {
return nil, err
}
if i := strings.LastIndexByte(host, '%'); i > 0 && strings.IndexByte(host, ':') >= 0 {
// Remove the scope, if any. ResolveUDPAddr below will use it, but here we're just
// trying to make sure with a small sanity test that this is a real IP address and
// not something that's likely to incur DNS lookups.
host = host[:i]
}
if ip := net.ParseIP(host); ip == nil {
return nil, errors.New("Failed to parse IP address: " + host)
}
// parse address and port
addr, err := net.ResolveUDPAddr("udp", s)
if err != nil {
return nil, err
}
ip4 := addr.IP.To4()
if ip4 != nil {
addr.IP = ip4
}
return addr, err
}
func unsafeCloseBind(device *Device) error {
var err error
netc := &device.net
if netc.bind != nil {
err = netc.bind.Close()
netc.bind = nil
}
netc.stopping.Wait()
return err
}
func (device *Device) BindSetMark(mark uint32) error {
device.net.Lock()
defer device.net.Unlock()
// check if modified
if device.net.fwmark == mark {
return nil
}
// update fwmark on existing bind
device.net.fwmark = mark
if device.isUp.Get() && device.net.bind != nil {
if err := device.net.bind.SetMark(mark); err != nil {
return err
}
}
// clear cached source addresses
device.peers.RLock()
for _, peer := range device.peers.keyMap {
peer.Lock()
defer peer.Unlock()
if peer.endpoint != nil {
peer.endpoint.ClearSrc()
}
}
device.peers.RUnlock()
return nil
}
func (device *Device) BindUpdate() error {
device.net.Lock()
defer device.net.Unlock()
// close existing sockets
if err := unsafeCloseBind(device); err != nil {
return err
}
// open new sockets
if device.isUp.Get() {
// bind to new port
var err error
netc := &device.net
netc.bind, netc.port, err = CreateBind(netc.port, device)
if err != nil {
netc.bind = nil
netc.port = 0
return err
}
// set fwmark
if netc.fwmark != 0 {
err = netc.bind.SetMark(netc.fwmark)
if err != nil {
return err
}
}
// clear cached source addresses
device.peers.RLock()
for _, peer := range device.peers.keyMap {
peer.Lock()
defer peer.Unlock()
if peer.endpoint != nil {
peer.endpoint.ClearSrc()
}
}
device.peers.RUnlock()
// start receiving routines
device.net.starting.Add(ConnRoutineNumber)
device.net.stopping.Add(ConnRoutineNumber)
go device.RoutineReceiveIncoming(ipv4.Version, netc.bind)
go device.RoutineReceiveIncoming(ipv6.Version, netc.bind)
device.net.starting.Wait()
device.log.Debug.Println("UDP bind has been updated")
}
return nil
}
func (device *Device) BindClose() error {
device.net.Lock()
err := unsafeCloseBind(device)
device.net.Unlock()
return err
}
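As the comments in parseEndpoint above note, only literal IP addresses are accepted, so configuration can never trigger a DNS lookup at this layer. A hypothetical caller sketch:

// Illustrative sketch: literal "ip:port" strings parse; hostnames are rejected.
func exampleParseEndpoint() {
	if addr, err := parseEndpoint("192.0.2.10:51820"); err == nil {
		_ = addr.IP // 4-byte form, since To4() succeeded above
	}
	if _, err := parseEndpoint("demo.example.com:51820"); err != nil {
		// rejected: the host part is not a literal IP address
	}
}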

170
device/conn_default.go Normal file
View file

@ -0,0 +1,170 @@
// +build !linux android
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"net"
"os"
"syscall"
)
/* This code is meant to be a temporary solution
* on platforms for which the sticky socket / source caching behavior
* has not yet been implemented.
*
* See conn_linux.go for an implementation on the linux platform.
*/
type nativeBind struct {
ipv4 *net.UDPConn
ipv6 *net.UDPConn
}
type NativeEndpoint net.UDPAddr
var _ Bind = (*nativeBind)(nil)
var _ Endpoint = (*NativeEndpoint)(nil)
func CreateEndpoint(s string) (Endpoint, error) {
addr, err := parseEndpoint(s)
return (*NativeEndpoint)(addr), err
}
func (_ *NativeEndpoint) ClearSrc() {}
func (e *NativeEndpoint) DstIP() net.IP {
return (*net.UDPAddr)(e).IP
}
func (e *NativeEndpoint) SrcIP() net.IP {
return nil // not supported
}
func (e *NativeEndpoint) DstToBytes() []byte {
addr := (*net.UDPAddr)(e)
out := addr.IP.To4()
if out == nil {
out = addr.IP
}
out = append(out, byte(addr.Port&0xff))
out = append(out, byte((addr.Port>>8)&0xff))
return out
}
func (e *NativeEndpoint) DstToString() string {
return (*net.UDPAddr)(e).String()
}
func (e *NativeEndpoint) SrcToString() string {
return ""
}
func listenNet(network string, port int) (*net.UDPConn, int, error) {
// listen
conn, err := net.ListenUDP(network, &net.UDPAddr{Port: port})
if err != nil {
return nil, 0, err
}
// retrieve port
laddr := conn.LocalAddr()
uaddr, err := net.ResolveUDPAddr(
laddr.Network(),
laddr.String(),
)
if err != nil {
return nil, 0, err
}
return conn, uaddr.Port, nil
}
func extractErrno(err error) error {
opErr, ok := err.(*net.OpError)
if !ok {
return nil
}
syscallErr, ok := opErr.Err.(*os.SyscallError)
if !ok {
return nil
}
return syscallErr.Err
}
func CreateBind(uport uint16, device *Device) (Bind, uint16, error) {
var err error
var bind nativeBind
port := int(uport)
bind.ipv4, port, err = listenNet("udp4", port)
if err != nil && extractErrno(err) != syscall.EAFNOSUPPORT {
return nil, 0, err
}
bind.ipv6, port, err = listenNet("udp6", port)
if err != nil && extractErrno(err) != syscall.EAFNOSUPPORT {
bind.ipv4.Close()
bind.ipv4 = nil
return nil, 0, err
}
return &bind, uint16(port), nil
}
func (bind *nativeBind) Close() error {
var err1, err2 error
if bind.ipv4 != nil {
err1 = bind.ipv4.Close()
}
if bind.ipv6 != nil {
err2 = bind.ipv6.Close()
}
if err1 != nil {
return err1
}
return err2
}
func (bind *nativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
if bind.ipv4 == nil {
return 0, nil, syscall.EAFNOSUPPORT
}
n, endpoint, err := bind.ipv4.ReadFromUDP(buff)
if endpoint != nil {
endpoint.IP = endpoint.IP.To4()
}
return n, (*NativeEndpoint)(endpoint), err
}
func (bind *nativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
if bind.ipv6 == nil {
return 0, nil, syscall.EAFNOSUPPORT
}
n, endpoint, err := bind.ipv6.ReadFromUDP(buff)
return n, (*NativeEndpoint)(endpoint), err
}
func (bind *nativeBind) Send(buff []byte, endpoint Endpoint) error {
var err error
nend := endpoint.(*NativeEndpoint)
if nend.IP.To4() != nil {
if bind.ipv4 == nil {
return syscall.EAFNOSUPPORT
}
_, err = bind.ipv4.WriteToUDP(buff, (*net.UDPAddr)(nend))
} else {
if bind.ipv6 == nil {
return syscall.EAFNOSUPPORT
}
_, err = bind.ipv6.WriteToUDP(buff, (*net.UDPAddr)(nend))
}
return err
}
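DstToBytes above is what feeds the mac2 cookie calculation, and its layout is worth seeing with concrete numbers; a hypothetical sketch using documentation addresses:

// Illustrative sketch: 192.0.2.10:51820 serializes as the four address bytes
// followed by the port in little-endian order (51820 = 0xCA6C -> 0x6C, 0xCA).
func exampleDstBytes() []byte {
	ep := &NativeEndpoint{IP: net.IPv4(192, 0, 2, 10), Port: 51820}
	return ep.DstToBytes() // [192 0 2 10 108 202]
}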

757
device/conn_linux.go Normal file
View file

@ -0,0 +1,757 @@
// +build !android
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*
* This implements userspace semantics of "sticky sockets", modeled after
* WireGuard's kernelspace implementation. This is more or less a straight port
* of the sticky-sockets.c example code:
* https://git.zx2c4.com/WireGuard/tree/contrib/examples/sticky-sockets/sticky-sockets.c
*
* Currently there is no way to achieve this within the net package:
* See e.g. https://github.com/golang/go/issues/17930
 * So this code remains platform dependent.
*/
package device
import (
"errors"
"net"
"strconv"
"sync"
"syscall"
"unsafe"
"golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/rwcancel"
)
const (
FD_ERR = -1
)
type IPv4Source struct {
src [4]byte
ifindex int32
}
type IPv6Source struct {
src [16]byte
//ifindex belongs in dst.ZoneId
}
type NativeEndpoint struct {
dst [unsafe.Sizeof(unix.SockaddrInet6{})]byte
src [unsafe.Sizeof(IPv6Source{})]byte
isV6 bool
}
func (endpoint *NativeEndpoint) src4() *IPv4Source {
return (*IPv4Source)(unsafe.Pointer(&endpoint.src[0]))
}
func (endpoint *NativeEndpoint) src6() *IPv6Source {
return (*IPv6Source)(unsafe.Pointer(&endpoint.src[0]))
}
func (endpoint *NativeEndpoint) dst4() *unix.SockaddrInet4 {
return (*unix.SockaddrInet4)(unsafe.Pointer(&endpoint.dst[0]))
}
func (endpoint *NativeEndpoint) dst6() *unix.SockaddrInet6 {
return (*unix.SockaddrInet6)(unsafe.Pointer(&endpoint.dst[0]))
}
type nativeBind struct {
sock4 int
sock6 int
netlinkSock int
netlinkCancel *rwcancel.RWCancel
lastMark uint32
}
var _ Endpoint = (*NativeEndpoint)(nil)
var _ Bind = (*nativeBind)(nil)
func CreateEndpoint(s string) (Endpoint, error) {
var end NativeEndpoint
addr, err := parseEndpoint(s)
if err != nil {
return nil, err
}
ipv4 := addr.IP.To4()
if ipv4 != nil {
dst := end.dst4()
end.isV6 = false
dst.Port = addr.Port
copy(dst.Addr[:], ipv4)
end.ClearSrc()
return &end, nil
}
ipv6 := addr.IP.To16()
if ipv6 != nil {
zone, err := zoneToUint32(addr.Zone)
if err != nil {
return nil, err
}
dst := end.dst6()
end.isV6 = true
dst.Port = addr.Port
dst.ZoneId = zone
copy(dst.Addr[:], ipv6[:])
end.ClearSrc()
return &end, nil
}
return nil, errors.New("Invalid IP address")
}
func createNetlinkRouteSocket() (int, error) {
sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, unix.NETLINK_ROUTE)
if err != nil {
return -1, err
}
saddr := &unix.SockaddrNetlink{
Family: unix.AF_NETLINK,
Groups: uint32(1 << (unix.RTNLGRP_IPV4_ROUTE - 1)),
}
err = unix.Bind(sock, saddr)
if err != nil {
unix.Close(sock)
return -1, err
}
return sock, nil
}
func CreateBind(port uint16, device *Device) (*nativeBind, uint16, error) {
var err error
var bind nativeBind
var newPort uint16
bind.netlinkSock, err = createNetlinkRouteSocket()
if err != nil {
return nil, 0, err
}
bind.netlinkCancel, err = rwcancel.NewRWCancel(bind.netlinkSock)
if err != nil {
unix.Close(bind.netlinkSock)
return nil, 0, err
}
go bind.routineRouteListener(device)
// attempt ipv6 bind, update port if successful
bind.sock6, newPort, err = create6(port)
if err != nil {
if err != syscall.EAFNOSUPPORT {
bind.netlinkCancel.Cancel()
return nil, 0, err
}
} else {
port = newPort
}
// attempt ipv4 bind, update port if successful
bind.sock4, newPort, err = create4(port)
if err != nil {
if err != syscall.EAFNOSUPPORT {
bind.netlinkCancel.Cancel()
unix.Close(bind.sock6)
return nil, 0, err
}
} else {
port = newPort
}
if bind.sock4 == FD_ERR && bind.sock6 == FD_ERR {
return nil, 0, errors.New("ipv4 and ipv6 not supported")
}
return &bind, port, nil
}
func (bind *nativeBind) SetMark(value uint32) error {
if bind.sock6 != -1 {
err := unix.SetsockoptInt(
bind.sock6,
unix.SOL_SOCKET,
unix.SO_MARK,
int(value),
)
if err != nil {
return err
}
}
if bind.sock4 != -1 {
err := unix.SetsockoptInt(
bind.sock4,
unix.SOL_SOCKET,
unix.SO_MARK,
int(value),
)
if err != nil {
return err
}
}
bind.lastMark = value
return nil
}
func closeUnblock(fd int) error {
// shutdown to unblock readers and writers
unix.Shutdown(fd, unix.SHUT_RDWR)
return unix.Close(fd)
}
func (bind *nativeBind) Close() error {
var err1, err2, err3 error
if bind.sock6 != -1 {
err1 = closeUnblock(bind.sock6)
}
if bind.sock4 != -1 {
err2 = closeUnblock(bind.sock4)
}
err3 = bind.netlinkCancel.Cancel()
if err1 != nil {
return err1
}
if err2 != nil {
return err2
}
return err3
}
func (bind *nativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
var end NativeEndpoint
if bind.sock6 == -1 {
return 0, nil, syscall.EAFNOSUPPORT
}
n, err := receive6(
bind.sock6,
buff,
&end,
)
return n, &end, err
}
func (bind *nativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
var end NativeEndpoint
if bind.sock4 == -1 {
return 0, nil, syscall.EAFNOSUPPORT
}
n, err := receive4(
bind.sock4,
buff,
&end,
)
return n, &end, err
}
func (bind *nativeBind) Send(buff []byte, end Endpoint) error {
nend := end.(*NativeEndpoint)
if !nend.isV6 {
if bind.sock4 == -1 {
return syscall.EAFNOSUPPORT
}
return send4(bind.sock4, nend, buff)
} else {
if bind.sock6 == -1 {
return syscall.EAFNOSUPPORT
}
return send6(bind.sock6, nend, buff)
}
}
func (end *NativeEndpoint) SrcIP() net.IP {
if !end.isV6 {
return net.IPv4(
end.src4().src[0],
end.src4().src[1],
end.src4().src[2],
end.src4().src[3],
)
} else {
return end.src6().src[:]
}
}
func (end *NativeEndpoint) DstIP() net.IP {
if !end.isV6 {
return net.IPv4(
end.dst4().Addr[0],
end.dst4().Addr[1],
end.dst4().Addr[2],
end.dst4().Addr[3],
)
} else {
return end.dst6().Addr[:]
}
}
func (end *NativeEndpoint) DstToBytes() []byte {
if !end.isV6 {
return (*[unsafe.Offsetof(end.dst4().Addr) + unsafe.Sizeof(end.dst4().Addr)]byte)(unsafe.Pointer(end.dst4()))[:]
} else {
return (*[unsafe.Offsetof(end.dst6().Addr) + unsafe.Sizeof(end.dst6().Addr)]byte)(unsafe.Pointer(end.dst6()))[:]
}
}
func (end *NativeEndpoint) SrcToString() string {
return end.SrcIP().String()
}
func (end *NativeEndpoint) DstToString() string {
var udpAddr net.UDPAddr
udpAddr.IP = end.DstIP()
if !end.isV6 {
udpAddr.Port = end.dst4().Port
} else {
udpAddr.Port = end.dst6().Port
}
return udpAddr.String()
}
func (end *NativeEndpoint) ClearDst() {
for i := range end.dst {
end.dst[i] = 0
}
}
func (end *NativeEndpoint) ClearSrc() {
for i := range end.src {
end.src[i] = 0
}
}
func zoneToUint32(zone string) (uint32, error) {
if zone == "" {
return 0, nil
}
if intr, err := net.InterfaceByName(zone); err == nil {
return uint32(intr.Index), nil
}
n, err := strconv.ParseUint(zone, 10, 32)
return uint32(n), err
}
func create4(port uint16) (int, uint16, error) {
// create socket
fd, err := unix.Socket(
unix.AF_INET,
unix.SOCK_DGRAM,
0,
)
if err != nil {
return FD_ERR, 0, err
}
addr := unix.SockaddrInet4{
Port: int(port),
}
// set sockopts and bind
if err := func() error {
if err := unix.SetsockoptInt(
fd,
unix.SOL_SOCKET,
unix.SO_REUSEADDR,
1,
); err != nil {
return err
}
if err := unix.SetsockoptInt(
fd,
unix.IPPROTO_IP,
unix.IP_PKTINFO,
1,
); err != nil {
return err
}
return unix.Bind(fd, &addr)
}(); err != nil {
unix.Close(fd)
return FD_ERR, 0, err
}
sa, err := unix.Getsockname(fd)
if err == nil {
addr.Port = sa.(*unix.SockaddrInet4).Port
}
return fd, uint16(addr.Port), err
}
func create6(port uint16) (int, uint16, error) {
// create socket
fd, err := unix.Socket(
unix.AF_INET6,
unix.SOCK_DGRAM,
0,
)
if err != nil {
return FD_ERR, 0, err
}
// set sockopts and bind
addr := unix.SockaddrInet6{
Port: int(port),
}
if err := func() error {
if err := unix.SetsockoptInt(
fd,
unix.SOL_SOCKET,
unix.SO_REUSEADDR,
1,
); err != nil {
return err
}
if err := unix.SetsockoptInt(
fd,
unix.IPPROTO_IPV6,
unix.IPV6_RECVPKTINFO,
1,
); err != nil {
return err
}
if err := unix.SetsockoptInt(
fd,
unix.IPPROTO_IPV6,
unix.IPV6_V6ONLY,
1,
); err != nil {
return err
}
return unix.Bind(fd, &addr)
}(); err != nil {
unix.Close(fd)
return FD_ERR, 0, err
}
sa, err := unix.Getsockname(fd)
if err == nil {
addr.Port = sa.(*unix.SockaddrInet6).Port
}
return fd, uint16(addr.Port), err
}
func send4(sock int, end *NativeEndpoint, buff []byte) error {
// construct message header
cmsg := struct {
cmsghdr unix.Cmsghdr
pktinfo unix.Inet4Pktinfo
}{
unix.Cmsghdr{
Level: unix.IPPROTO_IP,
Type: unix.IP_PKTINFO,
Len: unix.SizeofInet4Pktinfo + unix.SizeofCmsghdr,
},
unix.Inet4Pktinfo{
Spec_dst: end.src4().src,
Ifindex: end.src4().ifindex,
},
}
_, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0)
if err == nil {
return nil
}
// clear src and retry
if err == unix.EINVAL {
end.ClearSrc()
cmsg.pktinfo = unix.Inet4Pktinfo{}
_, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0)
}
return err
}
func send6(sock int, end *NativeEndpoint, buff []byte) error {
// construct message header
cmsg := struct {
cmsghdr unix.Cmsghdr
pktinfo unix.Inet6Pktinfo
}{
unix.Cmsghdr{
Level: unix.IPPROTO_IPV6,
Type: unix.IPV6_PKTINFO,
Len: unix.SizeofInet6Pktinfo + unix.SizeofCmsghdr,
},
unix.Inet6Pktinfo{
Addr: end.src6().src,
Ifindex: end.dst6().ZoneId,
},
}
if cmsg.pktinfo.Addr == [16]byte{} {
cmsg.pktinfo.Ifindex = 0
}
_, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0)
if err == nil {
return nil
}
// clear src and retry
if err == unix.EINVAL {
end.ClearSrc()
cmsg.pktinfo = unix.Inet6Pktinfo{}
_, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0)
}
return err
}
func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) {
// construct message header
var cmsg struct {
cmsghdr unix.Cmsghdr
pktinfo unix.Inet4Pktinfo
}
size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0)
if err != nil {
return 0, err
}
end.isV6 = false
if newDst4, ok := newDst.(*unix.SockaddrInet4); ok {
*end.dst4() = *newDst4
}
// update source cache
if cmsg.cmsghdr.Level == unix.IPPROTO_IP &&
cmsg.cmsghdr.Type == unix.IP_PKTINFO &&
cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo {
end.src4().src = cmsg.pktinfo.Spec_dst
end.src4().ifindex = cmsg.pktinfo.Ifindex
}
return size, nil
}
func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) {
// construct message header
var cmsg struct {
cmsghdr unix.Cmsghdr
pktinfo unix.Inet6Pktinfo
}
size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0)
if err != nil {
return 0, err
}
end.isV6 = true
if newDst6, ok := newDst.(*unix.SockaddrInet6); ok {
*end.dst6() = *newDst6
}
// update source cache
if cmsg.cmsghdr.Level == unix.IPPROTO_IPV6 &&
cmsg.cmsghdr.Type == unix.IPV6_PKTINFO &&
cmsg.cmsghdr.Len >= unix.SizeofInet6Pktinfo {
end.src6().src = cmsg.pktinfo.Addr
end.dst6().ZoneId = cmsg.pktinfo.Ifindex
}
return size, nil
}
func (bind *nativeBind) routineRouteListener(device *Device) {
type peerEndpointPtr struct {
peer *Peer
endpoint *Endpoint
}
var reqPeer map[uint32]peerEndpointPtr
var reqPeerLock sync.Mutex
defer unix.Close(bind.netlinkSock)
for msg := make([]byte, 1<<16); ; {
var err error
var msgn int
for {
msgn, _, _, _, err = unix.Recvmsg(bind.netlinkSock, msg[:], nil, 0)
if err == nil || !rwcancel.RetryAfterError(err) {
break
}
if !bind.netlinkCancel.ReadyRead() {
return
}
}
if err != nil {
return
}
for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; {
hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0]))
if uint(hdr.Len) > uint(len(remain)) {
break
}
switch hdr.Type {
case unix.RTM_NEWROUTE, unix.RTM_DELROUTE:
if hdr.Seq <= MaxPeers && hdr.Seq > 0 {
if uint(len(remain)) < uint(hdr.Len) {
break
}
if hdr.Len > unix.SizeofNlMsghdr+unix.SizeofRtMsg {
attr := remain[unix.SizeofNlMsghdr+unix.SizeofRtMsg:]
for {
if uint(len(attr)) < uint(unix.SizeofRtAttr) {
break
}
attrhdr := *(*unix.RtAttr)(unsafe.Pointer(&attr[0]))
if attrhdr.Len < unix.SizeofRtAttr || uint(len(attr)) < uint(attrhdr.Len) {
break
}
if attrhdr.Type == unix.RTA_OIF && attrhdr.Len == unix.SizeofRtAttr+4 {
ifidx := *(*uint32)(unsafe.Pointer(&attr[unix.SizeofRtAttr]))
reqPeerLock.Lock()
if reqPeer == nil {
reqPeerLock.Unlock()
break
}
pePtr, ok := reqPeer[hdr.Seq]
reqPeerLock.Unlock()
if !ok {
break
}
pePtr.peer.Lock()
if &pePtr.peer.endpoint != pePtr.endpoint {
pePtr.peer.Unlock()
break
}
if uint32(pePtr.peer.endpoint.(*NativeEndpoint).src4().ifindex) == ifidx {
pePtr.peer.Unlock()
break
}
pePtr.peer.endpoint.(*NativeEndpoint).ClearSrc()
pePtr.peer.Unlock()
}
attr = attr[attrhdr.Len:]
}
}
break
}
reqPeerLock.Lock()
reqPeer = make(map[uint32]peerEndpointPtr)
reqPeerLock.Unlock()
go func() {
device.peers.RLock()
i := uint32(1)
for _, peer := range device.peers.keyMap {
peer.RLock()
if peer.endpoint == nil || peer.endpoint.(*NativeEndpoint) == nil {
peer.RUnlock()
continue
}
if peer.endpoint.(*NativeEndpoint).isV6 || peer.endpoint.(*NativeEndpoint).src4().ifindex == 0 {
peer.RUnlock()
break
}
nlmsg := struct {
hdr unix.NlMsghdr
msg unix.RtMsg
dsthdr unix.RtAttr
dst [4]byte
srchdr unix.RtAttr
src [4]byte
markhdr unix.RtAttr
mark uint32
}{
unix.NlMsghdr{
Type: uint16(unix.RTM_GETROUTE),
Flags: unix.NLM_F_REQUEST,
Seq: i,
},
unix.RtMsg{
Family: unix.AF_INET,
Dst_len: 32,
Src_len: 32,
},
unix.RtAttr{
Len: 8,
Type: unix.RTA_DST,
},
peer.endpoint.(*NativeEndpoint).dst4().Addr,
unix.RtAttr{
Len: 8,
Type: unix.RTA_SRC,
},
peer.endpoint.(*NativeEndpoint).src4().src,
unix.RtAttr{
Len: 8,
Type: unix.RTA_MARK,
},
uint32(bind.lastMark),
}
nlmsg.hdr.Len = uint32(unsafe.Sizeof(nlmsg))
reqPeerLock.Lock()
reqPeer[i] = peerEndpointPtr{
peer: peer,
endpoint: &peer.endpoint,
}
reqPeerLock.Unlock()
peer.RUnlock()
i++
_, err := bind.netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:])
if err != nil {
break
}
}
device.peers.RUnlock()
}()
}
remain = remain[hdr.Len:]
}
}
}
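
The Linux bind above implements "sticky sockets" by extracting each datagram's destination address and arrival interface from IP_PKTINFO control messages, so replies can be sent from the same local address. For illustration only (the repository uses raw unix syscalls, not this package), the same per-packet information can be read on Linux with golang.org/x/net/ipv4:

package main

import (
	"fmt"
	"net"

	"golang.org/x/net/ipv4"
)

func main() {
	c, err := net.ListenPacket("udp4", ":0")
	if err != nil {
		panic(err)
	}
	defer c.Close()

	p := ipv4.NewPacketConn(c)
	// Ask the kernel to report the destination address and interface index
	// of each datagram, the same data IP_PKTINFO carries.
	if err := p.SetControlMessage(ipv4.FlagDst|ipv4.FlagInterface, true); err != nil {
		panic(err)
	}

	buf := make([]byte, 1500)
	n, cm, src, err := p.ReadFrom(buf) // blocks until a datagram arrives
	if err != nil {
		panic(err)
	}
	fmt.Printf("%d bytes from %v, delivered to %v on ifindex %d\n", n, src, cm.Dst, cm.IfIndex)
}
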

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/ */
package device package device
@ -12,8 +12,8 @@ import (
/* Specification constants */ /* Specification constants */
const ( const (
RekeyAfterMessages = (1 << 60) RekeyAfterMessages = (1 << 64) - (1 << 16) - 1
RejectAfterMessages = (1 << 64) - (1 << 13) - 1 RejectAfterMessages = (1 << 64) - (1 << 4) - 1
RekeyAfterTime = time.Second * 120 RekeyAfterTime = time.Second * 120
RekeyAttemptTime = time.Second * 90 RekeyAttemptTime = time.Second * 90
RekeyTimeout = time.Second * 5 RekeyTimeout = time.Second * 5
@ -35,6 +35,7 @@ const (
/* Implementation constants */ /* Implementation constants */
const ( const (
UnderLoadQueueSize = QueueHandshakeSize / 8
UnderLoadAfterTime = time.Second // how long does the device remain under load after detected UnderLoadAfterTime = time.Second // how long does the device remain under load after detected
MaxPeers = 1 << 16 // maximum number of configured peers MaxPeers = 1 << 16 // maximum number of configured peers
) )
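
Both columns of the constants diff above fit in a uint64: the master column uses 1<<60 and (1<<64)-(1<<13)-1, the 0.0.20190908 column uses (1<<64)-(1<<16)-1 and (1<<64)-(1<<4)-1. A quick check of those expressions as plain Go constant arithmetic (illustrative only, not repository code):

package main

import "fmt"

func main() {
	// Master column.
	const rekeyAfterMessages uint64 = 1 << 60
	const rejectAfterMessages uint64 = (1 << 64) - (1 << 13) - 1
	// 0.0.20190908 column.
	const oldRekeyAfterMessages uint64 = (1 << 64) - (1 << 16) - 1
	const oldRejectAfterMessages uint64 = (1 << 64) - (1 << 4) - 1

	fmt.Println(rekeyAfterMessages, rejectAfterMessages)
	fmt.Println(oldRekeyAfterMessages, oldRejectAfterMessages)
}
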

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/ */
package device package device
@ -83,7 +83,7 @@ func (st *CookieChecker) CheckMAC1(msg []byte) bool {
return hmac.Equal(mac1[:], msg[smac1:smac2]) return hmac.Equal(mac1[:], msg[smac1:smac2])
} }
func (st *CookieChecker) CheckMAC2(msg, src []byte) bool { func (st *CookieChecker) CheckMAC2(msg []byte, src []byte) bool {
st.RLock() st.RLock()
defer st.RUnlock() defer st.RUnlock()
@ -119,6 +119,7 @@ func (st *CookieChecker) CreateReply(
recv uint32, recv uint32,
src []byte, src []byte,
) (*MessageCookieReply, error) { ) (*MessageCookieReply, error) {
st.RLock() st.RLock()
// refresh cookie secret // refresh cookie secret
@ -203,6 +204,7 @@ func (st *CookieGenerator) ConsumeReply(msg *MessageCookieReply) bool {
xchapoly, _ := chacha20poly1305.NewX(st.mac2.encryptionKey[:]) xchapoly, _ := chacha20poly1305.NewX(st.mac2.encryptionKey[:])
_, err := xchapoly.Open(cookie[:0], msg.Nonce[:], msg.Cookie[:], st.mac2.lastMAC1[:]) _, err := xchapoly.Open(cookie[:0], msg.Nonce[:], msg.Cookie[:], st.mac2.lastMAC1[:])
if err != nil { if err != nil {
return false return false
} }
@ -213,6 +215,7 @@ func (st *CookieGenerator) ConsumeReply(msg *MessageCookieReply) bool {
} }
func (st *CookieGenerator) AddMacs(msg []byte) { func (st *CookieGenerator) AddMacs(msg []byte) {
size := len(msg) size := len(msg)
smac2 := size - blake2s.Size128 smac2 := size - blake2s.Size128
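
CheckMAC1 above compares keyed BLAKE2s-128 digests in constant time via hmac.Equal. A self-contained sketch of that comparison pattern; the key and message here are placeholders, not protocol values:

package main

import (
	"crypto/hmac"
	"fmt"

	"golang.org/x/crypto/blake2s"
)

// mac128 computes a keyed BLAKE2s digest truncated to 16 bytes.
func mac128(key, msg []byte) []byte {
	h, err := blake2s.New128(key)
	if err != nil {
		panic(err)
	}
	h.Write(msg)
	return h.Sum(nil)
}

func main() {
	key := []byte("0123456789abcdef0123456789abcdef") // 32-byte placeholder key
	msg := []byte("handshake message body")

	a := mac128(key, msg)
	b := mac128(key, msg)
	// hmac.Equal compares in constant time, as the cookie checker does.
	fmt.Println(hmac.Equal(a, b)) // true
}
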

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/ */
package device package device
@ -10,6 +10,7 @@ import (
) )
func TestCookieMAC1(t *testing.T) { func TestCookieMAC1(t *testing.T) {
// setup generator / checker // setup generator / checker
var ( var (
@ -131,12 +132,12 @@ func TestCookieMAC1(t *testing.T) {
msg[5] ^= 0x20 msg[5] ^= 0x20
srcBad1 := []byte{192, 168, 13, 37, 40, 1} srcBad1 := []byte{192, 168, 13, 37, 40, 01}
if checker.CheckMAC2(msg, srcBad1) { if checker.CheckMAC2(msg, srcBad1) {
t.Fatal("MAC2 generation/verification failed") t.Fatal("MAC2 generation/verification failed")
} }
srcBad2 := []byte{192, 168, 13, 38, 40, 1} srcBad2 := []byte{192, 168, 13, 38, 40, 01}
if checker.CheckMAC2(msg, srcBad2) { if checker.CheckMAC2(msg, srcBad2) {
t.Fatal("MAC2 generation/verification failed") t.Fatal("MAC2 generation/verification failed")
} }

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/ */
package device package device
@ -11,42 +11,37 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/amnezia-vpn/amneziawg-go/conn" "golang.zx2c4.com/wireguard/ratelimiter"
"github.com/amnezia-vpn/amneziawg-go/ipc" "golang.zx2c4.com/wireguard/tun"
"github.com/amnezia-vpn/amneziawg-go/ratelimiter" )
"github.com/amnezia-vpn/amneziawg-go/rwcancel"
"github.com/amnezia-vpn/amneziawg-go/tun" const (
"github.com/tevino/abool/v2" DeviceRoutineNumberPerCPU = 3
DeviceRoutineNumberAdditional = 2
) )
type Device struct { type Device struct {
isUp AtomicBool // device is (going) up
isClosed AtomicBool // device is closed? (acting as guard)
log *Logger
// synchronized resources (locks acquired in order)
state struct { state struct {
// state holds the device's state. It is accessed atomically. starting sync.WaitGroup
// Use the device.deviceState method to read it.
// device.deviceState does not acquire the mutex, so it captures only a snapshot.
// During state transitions, the state variable is updated before the device itself.
// The state is thus either the current state of the device or
// the intended future state of the device.
// For example, while executing a call to Up, state will be deviceStateUp.
// There is no guarantee that that intended future state of the device
// will become the actual state; Up can fail.
// The device can also change state multiple times between time of check and time of use.
// Unsynchronized uses of state must therefore be advisory/best-effort only.
state atomic.Uint32 // actually a deviceState, but typed uint32 for convenience
// stopping blocks until all inputs to Device have been closed.
stopping sync.WaitGroup stopping sync.WaitGroup
// mu protects state changes.
sync.Mutex sync.Mutex
changing AtomicBool
current bool
} }
net struct { net struct {
starting sync.WaitGroup
stopping sync.WaitGroup stopping sync.WaitGroup
sync.RWMutex sync.RWMutex
bind conn.Bind // bind interface bind Bind // bind interface
netlinkCancel *rwcancel.RWCancel port uint16 // listening port
port uint16 // listening port fwmark uint32 // mark value (0 = disabled)
fwmark uint32 // mark value (0 = disabled)
brokenRoaming bool
} }
staticIdentity struct { staticIdentity struct {
@ -56,195 +51,153 @@ type Device struct {
} }
peers struct { peers struct {
sync.RWMutex // protects keyMap sync.RWMutex
keyMap map[NoisePublicKey]*Peer keyMap map[NoisePublicKey]*Peer
} }
rate struct { // unprotected / "self-synchronising resources"
underLoadUntil atomic.Int64
limiter ratelimiter.Ratelimiter
}
allowedips AllowedIPs allowedips AllowedIPs
indexTable IndexTable indexTable IndexTable
cookieChecker CookieChecker cookieChecker CookieChecker
rate struct {
underLoadUntil atomic.Value
limiter ratelimiter.Ratelimiter
}
pool struct { pool struct {
inboundElementsContainer *WaitPool messageBufferPool *sync.Pool
outboundElementsContainer *WaitPool messageBufferReuseChan chan *[MaxMessageSize]byte
messageBuffers *WaitPool inboundElementPool *sync.Pool
inboundElements *WaitPool inboundElementReuseChan chan *QueueInboundElement
outboundElements *WaitPool outboundElementPool *sync.Pool
outboundElementReuseChan chan *QueueOutboundElement
} }
queue struct { queue struct {
encryption *outboundQueue encryption chan *QueueOutboundElement
decryption *inboundQueue decryption chan *QueueInboundElement
handshake *handshakeQueue handshake chan QueueHandshakeElement
}
signals struct {
stop chan struct{}
} }
tun struct { tun struct {
device tun.Device device tun.Device
mtu atomic.Int32 mtu int32
} }
ipcMutex sync.RWMutex
closed chan struct{}
log *Logger
isASecOn abool.AtomicBool
aSecMux sync.RWMutex
aSecCfg aSecCfgType
junkCreator junkCreator
} }
type aSecCfgType struct { /* Converts the peer into a "zombie", which remains in the peer map,
isSet bool * but processes no packets and does not exists in the routing table.
junkPacketCount int * but processes no packets and does not exist in the routing table.
junkPacketMinSize int * Must hold device.peers.Mutex
junkPacketMaxSize int */
initPacketJunkSize int func unsafeRemovePeer(device *Device, peer *Peer, key NoisePublicKey) {
responsePacketJunkSize int
initPacketMagicHeader uint32
responsePacketMagicHeader uint32
underloadPacketMagicHeader uint32
transportPacketMagicHeader uint32
}
// deviceState represents the state of a Device.
// There are three states: down, up, closed.
// Transitions:
//
// down -----+
// ↑↓ ↓
// up -> closed
type deviceState uint32
//go:generate go run golang.org/x/tools/cmd/stringer -type deviceState -trimprefix=deviceState
const (
deviceStateDown deviceState = iota
deviceStateUp
deviceStateClosed
)
// deviceState returns device.state.state as a deviceState
// See those docs for how to interpret this value.
func (device *Device) deviceState() deviceState {
return deviceState(device.state.state.Load())
}
// isClosed reports whether the device is closed (or is closing).
// See device.state.state comments for how to interpret this value.
func (device *Device) isClosed() bool {
return device.deviceState() == deviceStateClosed
}
// isUp reports whether the device is up (or is attempting to come up).
// See device.state.state comments for how to interpret this value.
func (device *Device) isUp() bool {
return device.deviceState() == deviceStateUp
}
// Must hold device.peers.Lock()
func removePeerLocked(device *Device, peer *Peer, key NoisePublicKey) {
// stop routing and processing of packets // stop routing and processing of packets
device.allowedips.RemoveByPeer(peer) device.allowedips.RemoveByPeer(peer)
peer.Stop() peer.Stop()
// remove from peer map // remove from peer map
delete(device.peers.keyMap, key) delete(device.peers.keyMap, key)
} }
// changeState attempts to change the device state to match want. func deviceUpdateState(device *Device) {
func (device *Device) changeState(want deviceState) (err error) {
device.state.Lock() // check if state already being updated (guard)
defer device.state.Unlock()
old := device.deviceState() if device.state.changing.Swap(true) {
if old == deviceStateClosed { return
// once closed, always closed
device.log.Verbosef("Interface closed, ignored requested state %s", want)
return nil
} }
switch want {
case old: // compare to current state of device
return nil
case deviceStateUp: device.state.Lock()
device.state.state.Store(uint32(deviceStateUp))
err = device.upLocked() newIsUp := device.isUp.Get()
if err == nil {
if newIsUp == device.state.current {
device.state.changing.Set(false)
device.state.Unlock()
return
}
// change state of device
switch newIsUp {
case true:
if err := device.BindUpdate(); err != nil {
device.log.Error.Printf("Unable to update bind: %v\n", err)
device.isUp.Set(false)
break break
} }
fallthrough // up failed; bring the device all the way back down device.peers.RLock()
case deviceStateDown: for _, peer := range device.peers.keyMap {
device.state.state.Store(uint32(deviceStateDown)) peer.Start()
errDown := device.downLocked() if peer.persistentKeepaliveInterval > 0 {
if err == nil { peer.SendKeepalive()
err = errDown }
} }
} device.peers.RUnlock()
device.log.Verbosef(
"Interface state was %s, requested %s, now %s", old, want, device.deviceState())
return
}
// upLocked attempts to bring the device up and reports whether it succeeded. case false:
// The caller must hold device.state.mu and is responsible for updating device.state.state. device.BindClose()
func (device *Device) upLocked() error { device.peers.RLock()
if err := device.BindUpdate(); err != nil { for _, peer := range device.peers.keyMap {
device.log.Errorf("Unable to update bind: %v", err) peer.Stop()
return err
}
// The IPC set operation waits for peers to be created before calling Start() on them,
// so if there's a concurrent IPC set request happening, we should wait for it to complete.
device.ipcMutex.Lock()
defer device.ipcMutex.Unlock()
device.peers.RLock()
for _, peer := range device.peers.keyMap {
peer.Start()
if peer.persistentKeepaliveInterval.Load() > 0 {
peer.SendKeepalive()
} }
} device.peers.RUnlock()
device.peers.RUnlock()
return nil
}
// downLocked attempts to bring the device down.
// The caller must hold device.state.mu and is responsible for updating device.state.state.
func (device *Device) downLocked() error {
err := device.BindClose()
if err != nil {
device.log.Errorf("Bind close failed: %v", err)
} }
device.peers.RLock() // update state variables
for _, peer := range device.peers.keyMap {
peer.Stop() device.state.current = newIsUp
device.state.changing.Set(false)
device.state.Unlock()
// check for state change in the mean time
deviceUpdateState(device)
}
func (device *Device) Up() {
// closed device cannot be brought up
if device.isClosed.Get() {
return
} }
device.peers.RUnlock()
return err device.isUp.Set(true)
deviceUpdateState(device)
} }
func (device *Device) Up() error { func (device *Device) Down() {
return device.changeState(deviceStateUp) device.isUp.Set(false)
} deviceUpdateState(device)
func (device *Device) Down() error {
return device.changeState(deviceStateDown)
} }
func (device *Device) IsUnderLoad() bool { func (device *Device) IsUnderLoad() bool {
// check if currently under load // check if currently under load
now := time.Now() now := time.Now()
underLoad := len(device.queue.handshake.c) >= QueueHandshakeSize/8 underLoad := len(device.queue.handshake) >= UnderLoadQueueSize
if underLoad { if underLoad {
device.rate.underLoadUntil.Store(now.Add(UnderLoadAfterTime).UnixNano()) device.rate.underLoadUntil.Store(now.Add(UnderLoadAfterTime))
return true return true
} }
// check if recently under load // check if recently under load
return device.rate.underLoadUntil.Load() > now.UnixNano()
until := device.rate.underLoadUntil.Load().(time.Time)
return until.After(now)
} }
func (device *Device) SetPrivateKey(sk NoisePrivateKey) error { func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
@ -271,9 +224,7 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
publicKey := sk.publicKey() publicKey := sk.publicKey()
for key, peer := range device.peers.keyMap { for key, peer := range device.peers.keyMap {
if peer.handshake.remoteStatic.Equals(publicKey) { if peer.handshake.remoteStatic.Equals(publicKey) {
peer.handshake.mutex.RUnlock() unsafeRemovePeer(device, peer, key)
removePeerLocked(device, peer, key)
peer.handshake.mutex.RLock()
} }
} }
@ -285,11 +236,23 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
// do static-static DH pre-computations // do static-static DH pre-computations
rmKey := device.staticIdentity.privateKey.IsZero()
expiredPeers := make([]*Peer, 0, len(device.peers.keyMap)) expiredPeers := make([]*Peer, 0, len(device.peers.keyMap))
for _, peer := range device.peers.keyMap { for key, peer := range device.peers.keyMap {
handshake := &peer.handshake handshake := &peer.handshake
handshake.precomputedStaticStatic, _ = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic)
expiredPeers = append(expiredPeers, peer) if rmKey {
handshake.precomputedStaticStatic = [NoisePublicKeySize]byte{}
} else {
handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic)
}
if isZero(handshake.precomputedStaticStatic[:]) {
unsafeRemovePeer(device, peer, key)
} else {
expiredPeers = append(expiredPeers, peer)
}
} }
for _, peer := range lockedPeers { for _, peer := range lockedPeers {
@ -302,61 +265,66 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
return nil return nil
} }
func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device { func NewDevice(tunDevice tun.Device, logger *Logger) *Device {
device := new(Device) device := new(Device)
device.state.state.Store(uint32(deviceStateDown))
device.closed = make(chan struct{}) device.isUp.Set(false)
device.isClosed.Set(false)
device.log = logger device.log = logger
device.net.bind = bind
device.tun.device = tunDevice device.tun.device = tunDevice
mtu, err := device.tun.device.MTU() mtu, err := device.tun.device.MTU()
if err != nil { if err != nil {
device.log.Errorf("Trouble determining MTU, assuming default: %v", err) logger.Error.Println("Trouble determining MTU, assuming default:", err)
mtu = DefaultMTU mtu = DefaultMTU
} }
device.tun.mtu.Store(int32(mtu)) device.tun.mtu = int32(mtu)
device.peers.keyMap = make(map[NoisePublicKey]*Peer) device.peers.keyMap = make(map[NoisePublicKey]*Peer)
device.rate.limiter.Init() device.rate.limiter.Init()
device.rate.underLoadUntil.Store(time.Time{})
device.indexTable.Init() device.indexTable.Init()
device.allowedips.Reset()
device.PopulatePools() device.PopulatePools()
// create queues // create queues
device.queue.handshake = newHandshakeQueue() device.queue.handshake = make(chan QueueHandshakeElement, QueueHandshakeSize)
device.queue.encryption = newOutboundQueue() device.queue.encryption = make(chan *QueueOutboundElement, QueueOutboundSize)
device.queue.decryption = newInboundQueue() device.queue.decryption = make(chan *QueueInboundElement, QueueInboundSize)
// prepare signals
device.signals.stop = make(chan struct{})
// prepare net
device.net.port = 0
device.net.bind = nil
// start workers // start workers
cpus := runtime.NumCPU() cpus := runtime.NumCPU()
device.state.starting.Wait()
device.state.stopping.Wait() device.state.stopping.Wait()
device.queue.encryption.wg.Add(cpus) // One for each RoutineHandshake device.state.stopping.Add(DeviceRoutineNumberPerCPU*cpus + DeviceRoutineNumberAdditional)
for i := 0; i < cpus; i++ { device.state.starting.Add(DeviceRoutineNumberPerCPU*cpus + DeviceRoutineNumberAdditional)
go device.RoutineEncryption(i + 1) for i := 0; i < cpus; i += 1 {
go device.RoutineDecryption(i + 1) go device.RoutineEncryption()
go device.RoutineHandshake(i + 1) go device.RoutineDecryption()
go device.RoutineHandshake()
} }
device.state.stopping.Add(1) // RoutineReadFromTUN
device.queue.encryption.wg.Add(1) // RoutineReadFromTUN
go device.RoutineReadFromTUN() go device.RoutineReadFromTUN()
go device.RoutineTUNEventReader() go device.RoutineTUNEventReader()
return device device.state.starting.Wait()
}
// BatchSize returns the BatchSize for the device as a whole which is the max of return device
// the bind batch size and the tun batch size. The batch size reported by device
// is the size used to construct memory pools, and is the allowed batch size for
// the lifetime of the device.
func (device *Device) BatchSize() int {
size := device.net.bind.BatchSize()
dSize := device.tun.device.BatchSize()
if size < dSize {
size = dSize
}
return size
} }
func (device *Device) LookupPeer(pk NoisePublicKey) *Peer { func (device *Device) LookupPeer(pk NoisePublicKey) *Peer {
@ -373,7 +341,7 @@ func (device *Device) RemovePeer(key NoisePublicKey) {
peer, ok := device.peers.keyMap[key] peer, ok := device.peers.keyMap[key]
if ok { if ok {
removePeerLocked(device, peer, key) unsafeRemovePeer(device, peer, key)
} }
} }
@ -382,52 +350,67 @@ func (device *Device) RemoveAllPeers() {
defer device.peers.Unlock() defer device.peers.Unlock()
for key, peer := range device.peers.keyMap { for key, peer := range device.peers.keyMap {
removePeerLocked(device, peer, key) unsafeRemovePeer(device, peer, key)
} }
device.peers.keyMap = make(map[NoisePublicKey]*Peer) device.peers.keyMap = make(map[NoisePublicKey]*Peer)
} }
func (device *Device) FlushPacketQueues() {
for {
select {
case elem, ok := <-device.queue.decryption:
if ok {
elem.Drop()
}
case elem, ok := <-device.queue.encryption:
if ok {
elem.Drop()
}
case <-device.queue.handshake:
default:
return
}
}
}
func (device *Device) Close() { func (device *Device) Close() {
device.state.Lock() if device.isClosed.Swap(true) {
defer device.state.Unlock()
device.ipcMutex.Lock()
defer device.ipcMutex.Unlock()
if device.isClosed() {
return return
} }
device.state.state.Store(uint32(deviceStateClosed))
device.log.Verbosef("Device closing") device.state.starting.Wait()
device.log.Info.Println("Device closing")
device.state.changing.Set(true)
device.state.Lock()
defer device.state.Unlock()
device.tun.device.Close() device.tun.device.Close()
device.downLocked() device.BindClose()
device.isUp.Set(false)
close(device.signals.stop)
// Remove peers before closing queues,
// because peers assume that queues are active.
device.RemoveAllPeers() device.RemoveAllPeers()
// We kept a reference to the encryption and decryption queues,
// in case we started any new peers that might write to them.
// No new peers are coming; we are done with these queues.
device.queue.encryption.wg.Done()
device.queue.decryption.wg.Done()
device.queue.handshake.wg.Done()
device.state.stopping.Wait() device.state.stopping.Wait()
device.FlushPacketQueues()
device.rate.limiter.Close() device.rate.limiter.Close()
device.resetProtocol() device.state.changing.Set(false)
device.log.Info.Println("Interface closed")
device.log.Verbosef("Device closed")
close(device.closed)
} }
func (device *Device) Wait() chan struct{} { func (device *Device) Wait() chan struct{} {
return device.closed return device.signals.stop
} }
func (device *Device) SendKeepalivesToPeersWithCurrentKeypair() { func (device *Device) SendKeepalivesToPeersWithCurrentKeypair() {
if !device.isUp() { if device.isClosed.Get() {
return return
} }
@ -442,366 +425,3 @@ func (device *Device) SendKeepalivesToPeersWithCurrentKeypair() {
} }
device.peers.RUnlock() device.peers.RUnlock()
} }
// closeBindLocked closes the device's net.bind.
// The caller must hold the net mutex.
func closeBindLocked(device *Device) error {
var err error
netc := &device.net
if netc.netlinkCancel != nil {
netc.netlinkCancel.Cancel()
}
if netc.bind != nil {
err = netc.bind.Close()
}
netc.stopping.Wait()
return err
}
func (device *Device) Bind() conn.Bind {
device.net.Lock()
defer device.net.Unlock()
return device.net.bind
}
func (device *Device) BindSetMark(mark uint32) error {
device.net.Lock()
defer device.net.Unlock()
// check if modified
if device.net.fwmark == mark {
return nil
}
// update fwmark on existing bind
device.net.fwmark = mark
if device.isUp() && device.net.bind != nil {
if err := device.net.bind.SetMark(mark); err != nil {
return err
}
}
// clear cached source addresses
device.peers.RLock()
for _, peer := range device.peers.keyMap {
peer.markEndpointSrcForClearing()
}
device.peers.RUnlock()
return nil
}
func (device *Device) BindUpdate() error {
device.net.Lock()
defer device.net.Unlock()
// close existing sockets
if err := closeBindLocked(device); err != nil {
return err
}
// open new sockets
if !device.isUp() {
return nil
}
// bind to new port
var err error
var recvFns []conn.ReceiveFunc
netc := &device.net
recvFns, netc.port, err = netc.bind.Open(netc.port)
if err != nil {
netc.port = 0
return err
}
netc.netlinkCancel, err = device.startRouteListener(netc.bind)
if err != nil {
netc.bind.Close()
netc.port = 0
return err
}
// set fwmark
if netc.fwmark != 0 {
err = netc.bind.SetMark(netc.fwmark)
if err != nil {
return err
}
}
// clear cached source addresses
device.peers.RLock()
for _, peer := range device.peers.keyMap {
peer.markEndpointSrcForClearing()
}
device.peers.RUnlock()
// start receiving routines
device.net.stopping.Add(len(recvFns))
device.queue.decryption.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption
device.queue.handshake.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.handshake
batchSize := netc.bind.BatchSize()
for _, fn := range recvFns {
go device.RoutineReceiveIncoming(batchSize, fn)
}
device.log.Verbosef("UDP bind has been updated")
return nil
}
func (device *Device) BindClose() error {
device.net.Lock()
err := closeBindLocked(device)
device.net.Unlock()
return err
}
func (device *Device) isAdvancedSecurityOn() bool {
return device.isASecOn.IsSet()
}
func (device *Device) resetProtocol() {
// restore default message type values
MessageInitiationType = 1
MessageResponseType = 2
MessageCookieReplyType = 3
MessageTransportType = 4
}
func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
if !tempASecCfg.isSet {
return err
}
isASecOn := false
device.aSecMux.Lock()
if tempASecCfg.junkPacketCount < 0 {
err = ipcErrorf(
ipc.IpcErrorInvalid,
"JunkPacketCount should be non negative",
)
}
device.aSecCfg.junkPacketCount = tempASecCfg.junkPacketCount
if tempASecCfg.junkPacketCount != 0 {
isASecOn = true
}
device.aSecCfg.junkPacketMinSize = tempASecCfg.junkPacketMinSize
if tempASecCfg.junkPacketMinSize != 0 {
isASecOn = true
}
if device.aSecCfg.junkPacketCount > 0 &&
tempASecCfg.junkPacketMaxSize == tempASecCfg.junkPacketMinSize {
tempASecCfg.junkPacketMaxSize++ // keep max > min so random junk sizing has a non-empty range
}
if tempASecCfg.junkPacketMaxSize >= MaxSegmentSize {
device.aSecCfg.junkPacketMinSize = 0
device.aSecCfg.junkPacketMaxSize = 1
if err != nil {
err = ipcErrorf(
ipc.IpcErrorInvalid,
"JunkPacketMaxSize: %d; should be smaller than maxSegmentSize: %d; %w",
tempASecCfg.junkPacketMaxSize,
MaxSegmentSize,
err,
)
} else {
err = ipcErrorf(
ipc.IpcErrorInvalid,
"JunkPacketMaxSize: %d; should be smaller than maxSegmentSize: %d",
tempASecCfg.junkPacketMaxSize,
MaxSegmentSize,
)
}
} else if tempASecCfg.junkPacketMaxSize < tempASecCfg.junkPacketMinSize {
if err != nil {
err = ipcErrorf(
ipc.IpcErrorInvalid,
"maxSize: %d; should be greater than minSize: %d; %w",
tempASecCfg.junkPacketMaxSize,
tempASecCfg.junkPacketMinSize,
err,
)
} else {
err = ipcErrorf(
ipc.IpcErrorInvalid,
"maxSize: %d; should be greater than minSize: %d",
tempASecCfg.junkPacketMaxSize,
tempASecCfg.junkPacketMinSize,
)
}
} else {
device.aSecCfg.junkPacketMaxSize = tempASecCfg.junkPacketMaxSize
}
if tempASecCfg.junkPacketMaxSize != 0 {
isASecOn = true
}
if MessageInitiationSize+tempASecCfg.initPacketJunkSize >= MaxSegmentSize {
if err != nil {
err = ipcErrorf(
ipc.IpcErrorInvalid,
`init header size(148) + junkSize:%d; should be smaller than maxSegmentSize: %d; %w`,
tempASecCfg.initPacketJunkSize,
MaxSegmentSize,
err,
)
} else {
err = ipcErrorf(
ipc.IpcErrorInvalid,
`init header size(148) + junkSize:%d; should be smaller than maxSegmentSize: %d`,
tempASecCfg.initPacketJunkSize,
MaxSegmentSize,
)
}
} else {
device.aSecCfg.initPacketJunkSize = tempASecCfg.initPacketJunkSize
}
if tempASecCfg.initPacketJunkSize != 0 {
isASecOn = true
}
if MessageResponseSize+tempASecCfg.responsePacketJunkSize >= MaxSegmentSize {
if err != nil {
err = ipcErrorf(
ipc.IpcErrorInvalid,
`response header size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d; %w`,
tempASecCfg.responsePacketJunkSize,
MaxSegmentSize,
err,
)
} else {
err = ipcErrorf(
ipc.IpcErrorInvalid,
`response header size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d`,
tempASecCfg.responsePacketJunkSize,
MaxSegmentSize,
)
}
} else {
device.aSecCfg.responsePacketJunkSize = tempASecCfg.responsePacketJunkSize
}
if tempASecCfg.responsePacketJunkSize != 0 {
isASecOn = true
}
if tempASecCfg.initPacketMagicHeader > 4 {
isASecOn = true
device.log.Verbosef("UAPI: Updating init_packet_magic_header")
device.aSecCfg.initPacketMagicHeader = tempASecCfg.initPacketMagicHeader
MessageInitiationType = device.aSecCfg.initPacketMagicHeader
} else {
device.log.Verbosef("UAPI: Using default init type")
MessageInitiationType = 1
}
if tempASecCfg.responsePacketMagicHeader > 4 {
isASecOn = true
device.log.Verbosef("UAPI: Updating response_packet_magic_header")
device.aSecCfg.responsePacketMagicHeader = tempASecCfg.responsePacketMagicHeader
MessageResponseType = device.aSecCfg.responsePacketMagicHeader
} else {
device.log.Verbosef("UAPI: Using default response type")
MessageResponseType = 2
}
if tempASecCfg.underloadPacketMagicHeader > 4 {
isASecOn = true
device.log.Verbosef("UAPI: Updating underload_packet_magic_header")
device.aSecCfg.underloadPacketMagicHeader = tempASecCfg.underloadPacketMagicHeader
MessageCookieReplyType = device.aSecCfg.underloadPacketMagicHeader
} else {
device.log.Verbosef("UAPI: Using default underload type")
MessageCookieReplyType = 3
}
if tempASecCfg.transportPacketMagicHeader > 4 {
isASecOn = true
device.log.Verbosef("UAPI: Updating transport_packet_magic_header")
device.aSecCfg.transportPacketMagicHeader = tempASecCfg.transportPacketMagicHeader
MessageTransportType = device.aSecCfg.transportPacketMagicHeader
} else {
device.log.Verbosef("UAPI: Using default transport type")
MessageTransportType = 4
}
isSameMap := map[uint32]bool{}
isSameMap[MessageInitiationType] = true
isSameMap[MessageResponseType] = true
isSameMap[MessageCookieReplyType] = true
isSameMap[MessageTransportType] = true
// size will be different if same values
if len(isSameMap) != 4 {
if err != nil {
err = ipcErrorf(
ipc.IpcErrorInvalid,
`magic headers should differ; got: init:%d; recv:%d; unde:%d; tran:%d; %w`,
MessageInitiationType,
MessageResponseType,
MessageCookieReplyType,
MessageTransportType,
err,
)
} else {
err = ipcErrorf(
ipc.IpcErrorInvalid,
`magic headers should differ; got: init:%d; recv:%d; unde:%d; tran:%d`,
MessageInitiationType,
MessageResponseType,
MessageCookieReplyType,
MessageTransportType,
)
}
}
newInitSize := MessageInitiationSize + device.aSecCfg.initPacketJunkSize
newResponseSize := MessageResponseSize + device.aSecCfg.responsePacketJunkSize
if newInitSize == newResponseSize {
if err != nil {
err = ipcErrorf(
ipc.IpcErrorInvalid,
`new init size:%d; and new response size:%d; should differ; %w`,
newInitSize,
newResponseSize,
err,
)
} else {
err = ipcErrorf(
ipc.IpcErrorInvalid,
`new init size:%d; and new response size:%d; should differ`,
newInitSize,
newResponseSize,
)
}
} else {
packetSizeToMsgType = map[int]uint32{
newInitSize: MessageInitiationType,
newResponseSize: MessageResponseType,
MessageCookieReplySize: MessageCookieReplyType,
MessageTransportSize: MessageTransportType,
}
msgTypeToJunkSize = map[uint32]int{
MessageInitiationType: device.aSecCfg.initPacketJunkSize,
MessageResponseType: device.aSecCfg.responsePacketJunkSize,
MessageCookieReplyType: 0,
MessageTransportType: 0,
}
}
device.isASecOn.SetTo(isASecOn)
device.junkCreator, err = NewJunkCreator(device)
device.aSecMux.Unlock()
return err
}
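
handlePostConfig above enforces two structural rules on the obfuscation parameters: the four message-type magic headers must be pairwise distinct, and the junk-padded init and response packet sizes (148+s1 and 92+s2, per the error strings above) must differ. A self-contained sketch of just those two checks; the function name is illustrative, and the numeric values reappear in genASecurityConfigs later in this diff:

package main

import (
	"errors"
	"fmt"
)

// checkASecShape mirrors the distinctness and size rules enforced by
// handlePostConfig above, and nothing else.
func checkASecShape(h1, h2, h3, h4 uint32, s1, s2 int) error {
	seen := map[uint32]bool{h1: true, h2: true, h3: true, h4: true}
	if len(seen) != 4 {
		return errors.New("magic headers must be pairwise distinct")
	}
	if 148+s1 == 92+s2 { // init and response header sizes quoted in the errors above
		return errors.New("padded init and response sizes must differ")
	}
	return nil
}

func main() {
	fmt.Println(checkASecShape(123456, 67543, 123123, 32345, 30, 40)) // <nil>
	fmt.Println(checkASecShape(5, 5, 6, 7, 30, 40))                   // error: duplicate headers
}
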

View file

@ -1,572 +1,68 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/ */
package device package device
/* Create two device instances and simulate full WireGuard interaction
* without network dependencies
*/
import ( import (
"bytes" "bytes"
"encoding/hex"
"fmt"
"io"
"math/rand"
"net/netip"
"os"
"runtime"
"runtime/pprof"
"sync"
"sync/atomic"
"testing" "testing"
"time"
"github.com/amnezia-vpn/amneziawg-go/conn"
"github.com/amnezia-vpn/amneziawg-go/conn/bindtest"
"github.com/amnezia-vpn/amneziawg-go/tun"
"github.com/amnezia-vpn/amneziawg-go/tun/tuntest"
) )
// uapiCfg returns a string that contains cfg formatted for use with IpcSet. func TestDevice(t *testing.T) {
// cfg is a series of alternating key/value strings.
// uapiCfg exists because editors and humans like to insert
// whitespace into configs, which can cause failures, some of which are silent.
// For example, a leading blank newline causes the remainder
// of the config to be silently ignored.
func uapiCfg(cfg ...string) string {
if len(cfg)%2 != 0 {
panic("odd number of args to uapiReader")
}
buf := new(bytes.Buffer)
for i, s := range cfg {
buf.WriteString(s)
sep := byte('\n')
if i%2 == 0 {
sep = '='
}
buf.WriteByte(sep)
}
return buf.String()
}
// genConfigs generates a pair of configs that connect to each other. // prepare tun devices for generating traffic
// The configs use distinct, probably-usable ports.
func genConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) { tun1 := newDummyTUN("tun1")
var key1, key2 NoisePrivateKey tun2 := newDummyTUN("tun2")
_, err := rand.Read(key1[:])
_ = tun1
_ = tun2
// prepare endpoints
end1, err := CreateDummyEndpoint()
if err != nil { if err != nil {
tb.Errorf("unable to generate private key random bytes: %v", err) t.Error("failed to create endpoint:", err.Error())
} }
_, err = rand.Read(key2[:])
end2, err := CreateDummyEndpoint()
if err != nil { if err != nil {
tb.Errorf("unable to generate private key random bytes: %v", err) t.Error("failed to create endpoint:", err.Error())
} }
pub1, pub2 := key1.publicKey(), key2.publicKey()
cfgs[0] = uapiCfg( _ = end1
"private_key", hex.EncodeToString(key1[:]), _ = end2
"listen_port", "0",
"replace_peers", "true", // create binds
"public_key", hex.EncodeToString(pub2[:]),
"protocol_version", "1",
"replace_allowed_ips", "true",
"allowed_ip", "1.0.0.2/32",
)
endpointCfgs[0] = uapiCfg(
"public_key", hex.EncodeToString(pub2[:]),
"endpoint", "127.0.0.1:%d",
)
cfgs[1] = uapiCfg(
"private_key", hex.EncodeToString(key2[:]),
"listen_port", "0",
"replace_peers", "true",
"public_key", hex.EncodeToString(pub1[:]),
"protocol_version", "1",
"replace_allowed_ips", "true",
"allowed_ip", "1.0.0.1/32",
)
endpointCfgs[1] = uapiCfg(
"public_key", hex.EncodeToString(pub1[:]),
"endpoint", "127.0.0.1:%d",
)
return
} }
func genASecurityConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) { func randDevice(t *testing.T) *Device {
var key1, key2 NoisePrivateKey sk, err := newPrivateKey()
_, err := rand.Read(key1[:])
if err != nil { if err != nil {
tb.Errorf("unable to generate private key random bytes: %v", err) t.Fatal(err)
} }
_, err = rand.Read(key2[:]) tun := newDummyTUN("dummy")
logger := NewLogger(LogLevelError, "")
device := NewDevice(tun, logger)
device.SetPrivateKey(sk)
return device
}
func assertNil(t *testing.T, err error) {
if err != nil { if err != nil {
tb.Errorf("unable to generate private key random bytes: %v", err) t.Fatal(err)
}
pub1, pub2 := key1.publicKey(), key2.publicKey()
cfgs[0] = uapiCfg(
"private_key", hex.EncodeToString(key1[:]),
"listen_port", "0",
"replace_peers", "true",
"jc", "5",
"jmin", "500",
"jmax", "1000",
"s1", "30",
"s2", "40",
"h1", "123456",
"h2", "67543",
"h4", "32345",
"h3", "123123",
"public_key", hex.EncodeToString(pub2[:]),
"protocol_version", "1",
"replace_allowed_ips", "true",
"allowed_ip", "1.0.0.2/32",
)
endpointCfgs[0] = uapiCfg(
"public_key", hex.EncodeToString(pub2[:]),
"endpoint", "127.0.0.1:%d",
)
cfgs[1] = uapiCfg(
"private_key", hex.EncodeToString(key2[:]),
"listen_port", "0",
"replace_peers", "true",
"jc", "5",
"jmin", "500",
"jmax", "1000",
"s1", "30",
"s2", "40",
"h1", "123456",
"h2", "67543",
"h4", "32345",
"h3", "123123",
"public_key", hex.EncodeToString(pub1[:]),
"protocol_version", "1",
"replace_allowed_ips", "true",
"allowed_ip", "1.0.0.1/32",
)
endpointCfgs[1] = uapiCfg(
"public_key", hex.EncodeToString(pub1[:]),
"endpoint", "127.0.0.1:%d",
)
return
}
// A testPair is a pair of testPeers.
type testPair [2]testPeer
// A testPeer is a peer used for testing.
type testPeer struct {
tun *tuntest.ChannelTUN
dev *Device
ip netip.Addr
}
type SendDirection bool
const (
Ping SendDirection = true
Pong SendDirection = false
)
func (d SendDirection) String() string {
if d == Ping {
return "ping"
}
return "pong"
}
func (pair *testPair) Send(
tb testing.TB,
ping SendDirection,
done chan struct{},
) {
tb.Helper()
p0, p1 := pair[0], pair[1]
if !ping {
// pong is the new ping
p0, p1 = p1, p0
}
msg := tuntest.Ping(p0.ip, p1.ip)
p1.tun.Outbound <- msg
timer := time.NewTimer(5 * time.Second)
defer timer.Stop()
var err error
select {
case msgRecv := <-p0.tun.Inbound:
if !bytes.Equal(msg, msgRecv) {
err = fmt.Errorf("%s did not transit correctly", ping)
}
case <-timer.C:
err = fmt.Errorf("%s did not transit", ping)
case <-done:
}
if err != nil {
// The error may have occurred because the test is done.
select {
case <-done:
return
default:
}
// Real error.
tb.Error(err)
} }
} }
// genTestPair creates a testPair. func assertEqual(t *testing.T, a, b []byte) {
func genTestPair( if !bytes.Equal(a, b) {
tb testing.TB, t.Fatal(a, "!=", b)
realSocket, withASecurity bool,
) (pair testPair) {
var cfg, endpointCfg [2]string
if withASecurity {
cfg, endpointCfg = genASecurityConfigs(tb)
} else {
cfg, endpointCfg = genConfigs(tb)
}
var binds [2]conn.Bind
if realSocket {
binds[0], binds[1] = conn.NewDefaultBind(), conn.NewDefaultBind()
} else {
binds = bindtest.NewChannelBinds()
}
// Bring up a ChannelTun for each config.
for i := range pair {
p := &pair[i]
p.tun = tuntest.NewChannelTUN()
p.ip = netip.AddrFrom4([4]byte{1, 0, 0, byte(i + 1)})
level := LogLevelVerbose
if _, ok := tb.(*testing.B); ok && !testing.Verbose() {
level = LogLevelError
}
p.dev = NewDevice(p.tun.TUN(), binds[i], NewLogger(level, fmt.Sprintf("dev%d: ", i)))
if err := p.dev.IpcSet(cfg[i]); err != nil {
tb.Errorf("failed to configure device %d: %v", i, err)
p.dev.Close()
continue
}
if err := p.dev.Up(); err != nil {
tb.Errorf("failed to bring up device %d: %v", i, err)
p.dev.Close()
continue
}
endpointCfg[i^1] = fmt.Sprintf(endpointCfg[i^1], p.dev.net.port)
}
for i := range pair {
p := &pair[i]
if err := p.dev.IpcSet(endpointCfg[i]); err != nil {
tb.Errorf("failed to configure device endpoint %d: %v", i, err)
p.dev.Close()
continue
}
// The device is ready. Close it when the test completes.
tb.Cleanup(p.dev.Close)
}
return
}
func TestTwoDevicePing(t *testing.T) {
goroutineLeakCheck(t)
pair := genTestPair(t, true, false)
t.Run("ping 1.0.0.1", func(t *testing.T) {
pair.Send(t, Ping, nil)
})
t.Run("ping 1.0.0.2", func(t *testing.T) {
pair.Send(t, Pong, nil)
})
}
func TestASecurityTwoDevicePing(t *testing.T) {
goroutineLeakCheck(t)
pair := genTestPair(t, true, true)
t.Run("ping 1.0.0.1", func(t *testing.T) {
pair.Send(t, Ping, nil)
})
t.Run("ping 1.0.0.2", func(t *testing.T) {
pair.Send(t, Pong, nil)
})
}
func TestUpDown(t *testing.T) {
goroutineLeakCheck(t)
const itrials = 50
const otrials = 10
for n := 0; n < otrials; n++ {
pair := genTestPair(t, false, false)
for i := range pair {
for k := range pair[i].dev.peers.keyMap {
pair[i].dev.IpcSet(fmt.Sprintf("public_key=%s\npersistent_keepalive_interval=1\n", hex.EncodeToString(k[:])))
}
}
var wg sync.WaitGroup
wg.Add(len(pair))
for i := range pair {
go func(d *Device) {
defer wg.Done()
for i := 0; i < itrials; i++ {
if err := d.Up(); err != nil {
t.Errorf("failed up bring up device: %v", err)
}
time.Sleep(time.Duration(rand.Intn(int(time.Nanosecond * (0x10000 - 1)))))
if err := d.Down(); err != nil {
t.Errorf("failed to bring down device: %v", err)
}
time.Sleep(time.Duration(rand.Intn(int(time.Nanosecond * (0x10000 - 1)))))
}
}(pair[i].dev)
}
wg.Wait()
for i := range pair {
pair[i].dev.Up()
pair[i].dev.Close()
}
}
}
// TestConcurrencySafety does other things concurrently with tunnel use.
// It is intended to be used with the race detector to catch data races.
func TestConcurrencySafety(t *testing.T) {
pair := genTestPair(t, true, false)
done := make(chan struct{})
const warmupIters = 10
var warmup sync.WaitGroup
warmup.Add(warmupIters)
go func() {
// Send data continuously back and forth until we're done.
// Note that we may continue to attempt to send data
// even after done is closed.
i := warmupIters
for ping := Ping; ; ping = !ping {
pair.Send(t, ping, done)
select {
case <-done:
return
default:
}
if i > 0 {
warmup.Done()
i--
}
}
}()
warmup.Wait()
applyCfg := func(cfg string) {
err := pair[0].dev.IpcSet(cfg)
if err != nil {
t.Fatal(err)
}
}
// Change persistent_keepalive_interval concurrently with tunnel use.
t.Run("persistentKeepaliveInterval", func(t *testing.T) {
var pub NoisePublicKey
for key := range pair[0].dev.peers.keyMap {
pub = key
break
}
cfg := uapiCfg(
"public_key", hex.EncodeToString(pub[:]),
"persistent_keepalive_interval", "1",
)
for i := 0; i < 1000; i++ {
applyCfg(cfg)
}
})
// Change private keys concurrently with tunnel use.
t.Run("privateKey", func(t *testing.T) {
bad := uapiCfg("private_key", "7777777777777777777777777777777777777777777777777777777777777777")
good := uapiCfg("private_key", hex.EncodeToString(pair[0].dev.staticIdentity.privateKey[:]))
// Set iters to a large number like 1000 to flush out data races quickly.
// Don't leave it large. That can cause logical races
// in which the handshake is interleaved with key changes
// such that the private key appears to be unchanging but
// other state gets reset, which can cause handshake failures like
// "Received packet with invalid mac1".
const iters = 1
for i := 0; i < iters; i++ {
applyCfg(bad)
applyCfg(good)
}
})
// Perform bind updates and keepalive sends concurrently with tunnel use.
t.Run("bindUpdate and keepalive", func(t *testing.T) {
const iters = 10
for i := 0; i < iters; i++ {
for _, peer := range pair {
peer.dev.BindUpdate()
peer.dev.SendKeepalivesToPeersWithCurrentKeypair()
}
}
})
close(done)
}
func BenchmarkLatency(b *testing.B) {
pair := genTestPair(b, true, false)
// Establish a connection.
pair.Send(b, Ping, nil)
pair.Send(b, Pong, nil)
b.ResetTimer()
for i := 0; i < b.N; i++ {
pair.Send(b, Ping, nil)
pair.Send(b, Pong, nil)
}
}
func BenchmarkThroughput(b *testing.B) {
pair := genTestPair(b, true, false)
// Establish a connection.
pair.Send(b, Ping, nil)
pair.Send(b, Pong, nil)
// Measure how long it takes to receive b.N packets,
// starting when we receive the first packet.
var recv atomic.Uint64
var elapsed time.Duration
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
var start time.Time
for {
<-pair[0].tun.Inbound
new := recv.Add(1)
if new == 1 {
start = time.Now()
}
// Careful! Don't change this to else if; b.N can be equal to 1.
if new == uint64(b.N) {
elapsed = time.Since(start)
return
}
}
}()
// Send packets as fast as we can until we've received enough.
ping := tuntest.Ping(pair[0].ip, pair[1].ip)
pingc := pair[1].tun.Outbound
var sent uint64
for recv.Load() != uint64(b.N) {
sent++
pingc <- ping
}
wg.Wait()
b.ReportMetric(float64(elapsed)/float64(b.N), "ns/op")
b.ReportMetric(1-float64(b.N)/float64(sent), "packet-loss")
}
func BenchmarkUAPIGet(b *testing.B) {
pair := genTestPair(b, true, false)
pair.Send(b, Ping, nil)
pair.Send(b, Pong, nil)
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
pair[0].dev.IpcGetOperation(io.Discard)
}
}
func goroutineLeakCheck(t *testing.T) {
goroutines := func() (int, []byte) {
p := pprof.Lookup("goroutine")
b := new(bytes.Buffer)
p.WriteTo(b, 1)
return p.Count(), b.Bytes()
}
startGoroutines, startStacks := goroutines()
t.Cleanup(func() {
if t.Failed() {
return
}
// Give goroutines time to exit, if they need it.
for i := 0; i < 10000; i++ {
if runtime.NumGoroutine() <= startGoroutines {
return
}
time.Sleep(1 * time.Millisecond)
}
endGoroutines, endStacks := goroutines()
t.Logf("starting stacks:\n%s\n", startStacks)
t.Logf("ending stacks:\n%s\n", endStacks)
t.Fatalf("expected %d goroutines, got %d, leak?", startGoroutines, endGoroutines)
})
}
type fakeBindSized struct {
size int
}
func (b *fakeBindSized) Open(
port uint16,
) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
return nil, 0, nil
}
func (b *fakeBindSized) Close() error { return nil }
func (b *fakeBindSized) SetMark(mark uint32) error { return nil }
func (b *fakeBindSized) Send(bufs [][]byte, ep conn.Endpoint) error { return nil }
func (b *fakeBindSized) ParseEndpoint(s string) (conn.Endpoint, error) { return nil, nil }
func (b *fakeBindSized) BatchSize() int { return b.size }
type fakeTUNDeviceSized struct {
size int
}
func (t *fakeTUNDeviceSized) File() *os.File { return nil }
func (t *fakeTUNDeviceSized) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
return 0, nil
}
func (t *fakeTUNDeviceSized) Write(bufs [][]byte, offset int) (int, error) { return 0, nil }
func (t *fakeTUNDeviceSized) MTU() (int, error) { return 0, nil }
func (t *fakeTUNDeviceSized) Name() (string, error) { return "", nil }
func (t *fakeTUNDeviceSized) Events() <-chan tun.Event { return nil }
func (t *fakeTUNDeviceSized) Close() error { return nil }
func (t *fakeTUNDeviceSized) BatchSize() int { return t.size }
func TestBatchSize(t *testing.T) {
d := Device{}
d.net.bind = &fakeBindSized{1}
d.tun.device = &fakeTUNDeviceSized{1}
if want, got := 1, d.BatchSize(); got != want {
t.Errorf("expected batch size %d, got %d", want, got)
}
d.net.bind = &fakeBindSized{1}
d.tun.device = &fakeTUNDeviceSized{128}
if want, got := 128, d.BatchSize(); got != want {
t.Errorf("expected batch size %d, got %d", want, got)
}
d.net.bind = &fakeBindSized{128}
d.tun.device = &fakeTUNDeviceSized{1}
if want, got := 128, d.BatchSize(); got != want {
t.Errorf("expected batch size %d, got %d", want, got)
}
d.net.bind = &fakeBindSized{128}
d.tun.device = &fakeTUNDeviceSized{128}
if want, got := 128, d.BatchSize(); got != want {
t.Errorf("expected batch size %d, got %d", want, got)
} }
} }
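
uapiCfg above renders alternating key/value arguments into the newline-delimited key=value form consumed by IpcSet, precisely to keep stray whitespace out of configs. A standalone sketch of that rendering (the helper name is illustrative, not the test's own):

package main

import (
	"bytes"
	"fmt"
)

// render mirrors uapiCfg: alternating keys and values become
// "key=value\n" lines with no extra whitespace.
func render(kv ...string) string {
	if len(kv)%2 != 0 {
		panic("odd number of arguments")
	}
	var buf bytes.Buffer
	for i, s := range kv {
		buf.WriteString(s)
		if i%2 == 0 {
			buf.WriteByte('=')
		} else {
			buf.WriteByte('\n')
		}
	}
	return buf.String()
}

func main() {
	fmt.Print(render(
		"listen_port", "0",
		"replace_peers", "true",
	))
	// Output:
	// listen_port=0
	// replace_peers=true
}
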

View file

@ -1,16 +0,0 @@
// Code generated by "stringer -type deviceState -trimprefix=deviceState"; DO NOT EDIT.
package device
import "strconv"
const _deviceState_name = "DownUpClosed"
var _deviceState_index = [...]uint8{0, 4, 6, 12}
func (i deviceState) String() string {
if i >= deviceState(len(_deviceState_index)-1) {
return "deviceState(" + strconv.FormatInt(int64(i), 10) + ")"
}
return _deviceState_name[_deviceState_index[i]:_deviceState_index[i+1]]
}
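
The deleted file above is generated code, produced by the stringer tool named in its header; the 0.0.20190908 tree has no deviceState type, so it simply disappears. Regeneration is driven by a go:generate directive next to the type declaration, roughly like this sketch (the uint32 underlying type is an assumption, not shown in this hunk):

// In device.go, next to the type declaration (sketch):
//
//go:generate stringer -type deviceState -trimprefix=deviceState
type deviceState uint32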


@ -1,49 +1,53 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/ */
package device package device
import ( import (
"math/rand" "math/rand"
"net/netip" "net"
) )
type DummyEndpoint struct { type DummyEndpoint struct {
src, dst netip.Addr src [16]byte
dst [16]byte
} }
func CreateDummyEndpoint() (*DummyEndpoint, error) { func CreateDummyEndpoint() (*DummyEndpoint, error) {
var src, dst [16]byte var end DummyEndpoint
if _, err := rand.Read(src[:]); err != nil { if _, err := rand.Read(end.src[:]); err != nil {
return nil, err return nil, err
} }
_, err := rand.Read(dst[:]) _, err := rand.Read(end.dst[:])
return &DummyEndpoint{netip.AddrFrom16(src), netip.AddrFrom16(dst)}, err return &end, err
} }
func (e *DummyEndpoint) ClearSrc() {} func (e *DummyEndpoint) ClearSrc() {}
func (e *DummyEndpoint) SrcToString() string { func (e *DummyEndpoint) SrcToString() string {
return netip.AddrPortFrom(e.SrcIP(), 1000).String() var addr net.UDPAddr
addr.IP = e.SrcIP()
addr.Port = 1000
return addr.String()
} }
func (e *DummyEndpoint) DstToString() string { func (e *DummyEndpoint) DstToString() string {
return netip.AddrPortFrom(e.DstIP(), 1000).String() var addr net.UDPAddr
addr.IP = e.DstIP()
addr.Port = 1000
return addr.String()
} }
func (e *DummyEndpoint) DstToBytes() []byte { func (e *DummyEndpoint) SrcToBytes() []byte {
out := e.DstIP().AsSlice() return e.src[:]
out = append(out, byte(1000&0xff))
out = append(out, byte((1000>>8)&0xff))
return out
} }
func (e *DummyEndpoint) DstIP() netip.Addr { func (e *DummyEndpoint) DstIP() net.IP {
return e.dst return e.dst[:]
} }
func (e *DummyEndpoint) SrcIP() netip.Addr { func (e *DummyEndpoint) SrcIP() net.IP {
return e.src return e.src[:]
} }
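
The master column rewrites DummyEndpoint on top of net/netip, so address formatting and the little-endian port suffix in DstToBytes become one-liners. A standalone sketch of those two operations using only the standard library:

package main

import (
	"fmt"
	"net/netip"
)

func main() {
	src := netip.AddrFrom4([4]byte{10, 0, 0, 1})
	ap := netip.AddrPortFrom(src, 1000)
	fmt.Println(ap.String()) // "10.0.0.1:1000", as SrcToString/DstToString now do

	b := src.AsSlice()                            // 4 or 16 bytes depending on the address family
	b = append(b, byte(1000&0xff), byte(1000>>8)) // little-endian port suffix, as in DstToBytes
	fmt.Println(b)
}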


@ -1,14 +1,14 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/ */
package device package device
import ( import (
"crypto/rand" "crypto/rand"
"encoding/binary"
"sync" "sync"
"unsafe"
) )
type IndexTableEntry struct { type IndexTableEntry struct {
@ -25,8 +25,7 @@ type IndexTable struct {
func randUint32() (uint32, error) { func randUint32() (uint32, error) {
var integer [4]byte var integer [4]byte
_, err := rand.Read(integer[:]) _, err := rand.Read(integer[:])
// Arbitrary endianness; both are intrinsified by the Go compiler. return *(*uint32)(unsafe.Pointer(&integer[0])), err
return binary.LittleEndian.Uint32(integer[:]), err
} }
func (table *IndexTable) Init() { func (table *IndexTable) Init() {
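
The master column replaces the unsafe pointer cast with encoding/binary; since the four input bytes are uniformly random, any fixed byte order gives the same distribution, which is what the new comment means by arbitrary endianness. A standalone sketch of the new form:

package main

import (
	"crypto/rand"
	"encoding/binary"
	"fmt"
)

func randUint32() (uint32, error) {
	var b [4]byte
	if _, err := rand.Read(b[:]); err != nil {
		return 0, err
	}
	// Any fixed byte order works here: the four bytes are uniformly random.
	return binary.LittleEndian.Uint32(b[:]), nil
}

func main() {
	fmt.Println(randUint32())
}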


@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/ */
package device package device


@ -1,69 +0,0 @@
package device
import (
"bytes"
crand "crypto/rand"
"fmt"
v2 "math/rand/v2"
)
type junkCreator struct {
device *Device
cha8Rand *v2.ChaCha8
}
func NewJunkCreator(d *Device) (junkCreator, error) {
buf := make([]byte, 32)
_, err := crand.Read(buf)
if err != nil {
return junkCreator{}, err
}
return junkCreator{device: d, cha8Rand: v2.NewChaCha8([32]byte(buf))}, nil
}
// Should be called with aSecMux RLocked
func (jc *junkCreator) createJunkPackets() ([][]byte, error) {
if jc.device.aSecCfg.junkPacketCount == 0 {
return nil, nil
}
junks := make([][]byte, 0, jc.device.aSecCfg.junkPacketCount)
for i := 0; i < jc.device.aSecCfg.junkPacketCount; i++ {
packetSize := jc.randomPacketSize()
junk, err := jc.randomJunkWithSize(packetSize)
if err != nil {
return nil, fmt.Errorf("Failed to create junk packet: %v", err)
}
junks = append(junks, junk)
}
return junks, nil
}
// Should be called with aSecMux RLocked
func (jc *junkCreator) randomPacketSize() int {
return int(
jc.cha8Rand.Uint64()%uint64(
jc.device.aSecCfg.junkPacketMaxSize-jc.device.aSecCfg.junkPacketMinSize,
),
) + jc.device.aSecCfg.junkPacketMinSize
}
// Should be called with aSecMux RLocked
func (jc *junkCreator) appendJunk(writer *bytes.Buffer, size int) error {
headerJunk, err := jc.randomJunkWithSize(size)
if err != nil {
return fmt.Errorf("failed to create header junk: %v", err)
}
_, err = writer.Write(headerJunk)
if err != nil {
return fmt.Errorf("failed to write header junk: %v", err)
}
return nil
}
// Should be called with aSecMux RLocked
func (jc *junkCreator) randomJunkWithSize(size int) ([]byte, error) {
junk := make([]byte, size)
_, err := jc.cha8Rand.Read(junk)
return junk, err
}
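
junk_creator.go exists only in the master tree: it seeds a ChaCha8 stream from crypto/rand once, then produces junkPacketCount random packets with sizes drawn from [junkPacketMinSize, junkPacketMaxSize). A hedged usage sketch, assumed to live inside package device (sendJunk, dev and peer are illustrative names; SendBuffers is the method shown further down in peer.go):

// sendJunk is a sketch, not repository code. Per the comments above, the junk
// creator methods are expected to run with the device's aSec mutex read-locked.
func sendJunk(dev *Device, peer *Peer) error {
	jc, err := NewJunkCreator(dev)
	if err != nil {
		return err
	}
	junks, err := jc.createJunkPackets() // nil, nil when junkPacketCount is 0
	if err != nil {
		return err
	}
	for _, junk := range junks {
		// each junk packet goes out ahead of the real handshake traffic
		if err := peer.SendBuffers([][]byte{junk}); err != nil {
			return err
		}
	}
	return nil
}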


@ -1,124 +0,0 @@
package device
import (
"bytes"
"fmt"
"testing"
"github.com/amnezia-vpn/amneziawg-go/conn/bindtest"
"github.com/amnezia-vpn/amneziawg-go/tun/tuntest"
)
func setUpJunkCreator(t *testing.T) (junkCreator, error) {
cfg, _ := genASecurityConfigs(t)
tun := tuntest.NewChannelTUN()
binds := bindtest.NewChannelBinds()
level := LogLevelVerbose
dev := NewDevice(
tun.TUN(),
binds[0],
NewLogger(level, ""),
)
if err := dev.IpcSet(cfg[0]); err != nil {
t.Errorf("failed to configure device %v", err)
dev.Close()
return junkCreator{}, err
}
jc, err := NewJunkCreator(dev)
if err != nil {
t.Errorf("failed to create junk creator %v", err)
dev.Close()
return junkCreator{}, err
}
return jc, nil
}
func Test_junkCreator_createJunkPackets(t *testing.T) {
jc, err := setUpJunkCreator(t)
if err != nil {
return
}
t.Run("", func(t *testing.T) {
got, err := jc.createJunkPackets()
if err != nil {
t.Errorf(
"junkCreator.createJunkPackets() = %v; failed",
err,
)
return
}
seen := make(map[string]bool)
for _, junk := range got {
key := string(junk)
if seen[key] {
t.Errorf(
"junkCreator.createJunkPackets() = %v, duplicate key: %v",
got,
junk,
)
return
}
seen[key] = true
}
})
}
func Test_junkCreator_randomJunkWithSize(t *testing.T) {
t.Run("", func(t *testing.T) {
jc, err := setUpJunkCreator(t)
if err != nil {
return
}
r1, _ := jc.randomJunkWithSize(10)
r2, _ := jc.randomJunkWithSize(10)
fmt.Printf("%v\n%v\n", r1, r2)
if bytes.Equal(r1, r2) {
t.Errorf("same junks %v", err)
jc.device.Close()
return
}
})
}
func Test_junkCreator_randomPacketSize(t *testing.T) {
jc, err := setUpJunkCreator(t)
if err != nil {
return
}
for range [30]struct{}{} {
t.Run("", func(t *testing.T) {
if got := jc.randomPacketSize(); jc.device.aSecCfg.junkPacketMinSize > got ||
got > jc.device.aSecCfg.junkPacketMaxSize {
t.Errorf(
"junkCreator.randomPacketSize() = %v, not between range [%v,%v]",
got,
jc.device.aSecCfg.junkPacketMinSize,
jc.device.aSecCfg.junkPacketMaxSize,
)
}
})
}
}
func Test_junkCreator_appendJunk(t *testing.T) {
jc, err := setUpJunkCreator(t)
if err != nil {
return
}
t.Run("", func(t *testing.T) {
s := "apple"
buffer := bytes.NewBuffer([]byte(s))
err := jc.appendJunk(buffer, 30)
if err != nil &&
buffer.Len() != len(s)+30 {
t.Errorf("appendWithJunk() size don't match")
}
read := make([]byte, 50)
buffer.Read(read)
fmt.Println(string(read))
})
}


@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/ */
package device package device
@ -20,7 +20,7 @@ type KDFTest struct {
t2 string t2 string
} }
func assertEquals(t *testing.T, a, b string) { func assertEquals(t *testing.T, a string, b string) {
if a != b { if a != b {
t.Fatal("expected", a, "=", b) t.Fatal("expected", a, "=", b)
} }


@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/ */
package device package device
@ -8,10 +8,9 @@ package device
import ( import (
"crypto/cipher" "crypto/cipher"
"sync" "sync"
"sync/atomic"
"time" "time"
"github.com/amnezia-vpn/amneziawg-go/replay" "golang.zx2c4.com/wireguard/replay"
) )
/* Due to limitations in Go and /x/crypto there is currently /* Due to limitations in Go and /x/crypto there is currently
@ -22,10 +21,10 @@ import (
*/ */
type Keypair struct { type Keypair struct {
sendNonce atomic.Uint64 sendNonce uint64
send cipher.AEAD send cipher.AEAD
receive cipher.AEAD receive cipher.AEAD
replayFilter replay.Filter replayFilter replay.ReplayFilter
isInitiator bool isInitiator bool
created time.Time created time.Time
localIndex uint32 localIndex uint32
@ -36,7 +35,7 @@ type Keypairs struct {
sync.RWMutex sync.RWMutex
current *Keypair current *Keypair
previous *Keypair previous *Keypair
next atomic.Pointer[Keypair] next *Keypair
} }
func (kp *Keypairs) Current() *Keypair { func (kp *Keypairs) Current() *Keypair {


@ -1,48 +1,59 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/ */
package device package device
import ( import (
"io"
"io/ioutil"
"log" "log"
"os" "os"
) )
// A Logger provides logging for a Device.
// The functions are Printf-style functions.
// They must be safe for concurrent use.
// They do not require a trailing newline in the format.
// If nil, that level of logging will be silent.
type Logger struct {
Verbosef func(format string, args ...any)
Errorf func(format string, args ...any)
}
// Log levels for use with NewLogger.
const ( const (
LogLevelSilent = iota LogLevelSilent = iota
LogLevelError LogLevelError
LogLevelVerbose LogLevelInfo
LogLevelDebug
) )
// Function for use in Logger for discarding logged lines. type Logger struct {
func DiscardLogf(format string, args ...any) {} Debug *log.Logger
Info *log.Logger
Error *log.Logger
}
// NewLogger constructs a Logger that writes to stdout.
// It logs at the specified log level and above.
// It decorates log lines with the log level, date, time, and prepend.
func NewLogger(level int, prepend string) *Logger { func NewLogger(level int, prepend string) *Logger {
logger := &Logger{DiscardLogf, DiscardLogf} output := os.Stdout
logf := func(prefix string) func(string, ...any) { logger := new(Logger)
return log.New(os.Stdout, prefix+": "+prepend, log.Ldate|log.Ltime).Printf
} logErr, logInfo, logDebug := func() (io.Writer, io.Writer, io.Writer) {
if level >= LogLevelVerbose { if level >= LogLevelDebug {
logger.Verbosef = logf("DEBUG") return output, output, output
} }
if level >= LogLevelError { if level >= LogLevelInfo {
logger.Errorf = logf("ERROR") return output, output, ioutil.Discard
} }
if level >= LogLevelError {
return output, ioutil.Discard, ioutil.Discard
}
return ioutil.Discard, ioutil.Discard, ioutil.Discard
}()
logger.Debug = log.New(logDebug,
"DEBUG: "+prepend,
log.Ldate|log.Ltime,
)
logger.Info = log.New(logInfo,
"INFO: "+prepend,
log.Ldate|log.Ltime,
)
logger.Error = log.New(logErr,
"ERROR: "+prepend,
log.Ldate|log.Ltime,
)
return logger return logger
} }
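
The master Logger is just two printf-style functions, with disabled levels wired to DiscardLogf so they are always safe to call; the 0.0.20190908 tree exposes three *log.Logger values instead. A hedged sketch of calling the newer shape, assumed to sit alongside this package (the interface name in the prefix is illustrative):

// loggerDemo is a sketch living in package device next to NewLogger.
func loggerDemo() {
	logger := NewLogger(LogLevelVerbose, "(wg0) ")
	logger.Verbosef("interface up, mtu=%d", 1420)
	logger.Errorf("failed to bring up device: %s", "example error")
}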

device/mark_default.go Normal file

@ -0,0 +1,12 @@
// +build !linux,!openbsd,!freebsd
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/
package device
func (bind *nativeBind) SetMark(mark uint32) error {
return nil
}


@ -1,11 +1,11 @@
//go:build linux || openbsd || freebsd // +build android openbsd freebsd
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/ */
package conn package device
import ( import (
"runtime" "runtime"
@ -26,13 +26,13 @@ func init() {
} }
} }
func (s *StdNetBind) SetMark(mark uint32) error { func (bind *nativeBind) SetMark(mark uint32) error {
var operr error var operr error
if fwmarkIoctl == 0 { if fwmarkIoctl == 0 {
return nil return nil
} }
if s.ipv4 != nil { if bind.ipv4 != nil {
fd, err := s.ipv4.SyscallConn() fd, err := bind.ipv4.SyscallConn()
if err != nil { if err != nil {
return err return err
} }
@ -46,8 +46,8 @@ func (s *StdNetBind) SetMark(mark uint32) error {
return err return err
} }
} }
if s.ipv6 != nil { if bind.ipv6 != nil {
fd, err := s.ipv6.SyscallConn() fd, err := bind.ipv6.SyscallConn()
if err != nil { if err != nil {
return err return err
} }

device/misc.go Normal file

@ -0,0 +1,48 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"sync/atomic"
)
/* Atomic Boolean */
const (
AtomicFalse = int32(iota)
AtomicTrue
)
type AtomicBool struct {
int32
}
func (a *AtomicBool) Get() bool {
return atomic.LoadInt32(&a.int32) == AtomicTrue
}
func (a *AtomicBool) Swap(val bool) bool {
flag := AtomicFalse
if val {
flag = AtomicTrue
}
return atomic.SwapInt32(&a.int32, flag) == AtomicTrue
}
func (a *AtomicBool) Set(val bool) {
flag := AtomicFalse
if val {
flag = AtomicTrue
}
atomic.StoreInt32(&a.int32, flag)
}
func min(a, b uint) uint {
if a > b {
return b
}
return a
}
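
The hand-rolled AtomicBool in misc.go predates the standard library's atomic.Bool; the master tree drops it in favour of sync/atomic, as the peer.go hunk below shows. A standalone sketch of the equivalent standard-library usage:

package main

import (
	"fmt"
	"sync/atomic"
)

func main() {
	var running atomic.Bool                 // standard-library replacement for AtomicBool
	wasRunning := running.Swap(true)        // same semantics as AtomicBool.Swap above
	fmt.Println(wasRunning, running.Load()) // false true
}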


@ -1,19 +0,0 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
// DisableSomeRoamingForBrokenMobileSemantics should ideally be called before peers are created,
// though it will try to deal with it, and race maybe, if called after.
func (device *Device) DisableSomeRoamingForBrokenMobileSemantics() {
device.net.brokenRoaming = true
device.peers.RLock()
for _, peer := range device.peers.keyMap {
peer.endpoint.Lock()
peer.endpoint.disableRoaming = peer.endpoint.val != nil
peer.endpoint.Unlock()
}
device.peers.RUnlock()
}


@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/ */
package device package device
@ -9,7 +9,6 @@ import (
"crypto/hmac" "crypto/hmac"
"crypto/rand" "crypto/rand"
"crypto/subtle" "crypto/subtle"
"errors"
"hash" "hash"
"golang.org/x/crypto/blake2s" "golang.org/x/crypto/blake2s"
@ -95,14 +94,9 @@ func (sk *NoisePrivateKey) publicKey() (pk NoisePublicKey) {
return return
} }
var errInvalidPublicKey = errors.New("invalid public key") func (sk *NoisePrivateKey) sharedSecret(pk NoisePublicKey) (ss [NoisePublicKeySize]byte) {
func (sk *NoisePrivateKey) sharedSecret(pk NoisePublicKey) (ss [NoisePublicKeySize]byte, err error) {
apk := (*[NoisePublicKeySize]byte)(&pk) apk := (*[NoisePublicKeySize]byte)(&pk)
ask := (*[NoisePrivateKeySize]byte)(sk) ask := (*[NoisePrivateKeySize]byte)(sk)
curve25519.ScalarMult(&ss, ask, apk) curve25519.ScalarMult(&ss, ask, apk)
if isZero(ss[:]) { return ss
return ss, errInvalidPublicKey
}
return ss, nil
} }
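
The master column makes sharedSecret return errInvalidPublicKey when the X25519 output is all zeros, which is what happens for low-order peer public keys; handshake code then aborts instead of deriving keys from a predictable value. The maintained x/crypto API performs the same check itself, as in this standalone sketch (the all-zero point is used purely as an example of a low-order input):

package main

import (
	"crypto/rand"
	"fmt"

	"golang.org/x/crypto/curve25519"
)

func main() {
	var priv [32]byte
	if _, err := rand.Read(priv[:]); err != nil {
		panic(err)
	}
	// The all-zero encoding is a low-order point; X25519 rejects it because the
	// shared secret would be all zeros, the same condition errInvalidPublicKey covers.
	var lowOrder [32]byte
	if _, err := curve25519.X25519(priv[:], lowOrder[:]); err != nil {
		fmt.Println("rejected:", err)
	}
}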


@ -1,50 +1,29 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/ */
package device package device
import ( import (
"errors" "errors"
"fmt"
"sync" "sync"
"time" "time"
"golang.org/x/crypto/blake2s" "golang.org/x/crypto/blake2s"
"golang.org/x/crypto/chacha20poly1305" "golang.org/x/crypto/chacha20poly1305"
"golang.org/x/crypto/poly1305" "golang.org/x/crypto/poly1305"
"golang.zx2c4.com/wireguard/tai64n"
"github.com/amnezia-vpn/amneziawg-go/tai64n"
) )
type handshakeState int
const ( const (
handshakeZeroed = handshakeState(iota) HandshakeZeroed = iota
handshakeInitiationCreated HandshakeInitiationCreated
handshakeInitiationConsumed HandshakeInitiationConsumed
handshakeResponseCreated HandshakeResponseCreated
handshakeResponseConsumed HandshakeResponseConsumed
) )
func (hs handshakeState) String() string {
switch hs {
case handshakeZeroed:
return "handshakeZeroed"
case handshakeInitiationCreated:
return "handshakeInitiationCreated"
case handshakeInitiationConsumed:
return "handshakeInitiationConsumed"
case handshakeResponseCreated:
return "handshakeResponseCreated"
case handshakeResponseConsumed:
return "handshakeResponseConsumed"
default:
return fmt.Sprintf("Handshake(UNKNOWN:%d)", int(hs))
}
}
const ( const (
NoiseConstruction = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s" NoiseConstruction = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s"
WGIdentifier = "WireGuard v1 zx2c4 Jason@zx2c4.com" WGIdentifier = "WireGuard v1 zx2c4 Jason@zx2c4.com"
@ -52,21 +31,21 @@ const (
WGLabelCookie = "cookie--" WGLabelCookie = "cookie--"
) )
var ( const (
MessageInitiationType uint32 = 1 MessageInitiationType = 1
MessageResponseType uint32 = 2 MessageResponseType = 2
MessageCookieReplyType uint32 = 3 MessageCookieReplyType = 3
MessageTransportType uint32 = 4 MessageTransportType = 4
) )
const ( const (
MessageInitiationSize = 148 // size of handshake initiation message MessageInitiationSize = 148 // size of handshake initation message
MessageResponseSize = 92 // size of response message MessageResponseSize = 92 // size of response message
MessageCookieReplySize = 64 // size of cookie reply message MessageCookieReplySize = 64 // size of cookie reply message
MessageTransportHeaderSize = 16 // size of data preceding content in transport message MessageTransportHeaderSize = 16 // size of data preceeding content in transport message
MessageTransportSize = MessageTransportHeaderSize + poly1305.TagSize // size of empty transport MessageTransportSize = MessageTransportHeaderSize + poly1305.TagSize // size of empty transport
MessageKeepaliveSize = MessageTransportSize // size of keepalive MessageKeepaliveSize = MessageTransportSize // size of keepalive
MessageHandshakeSize = MessageInitiationSize // size of largest handshake related message MessageHandshakeSize = MessageInitiationSize // size of largest handshake releated message
) )
const ( const (
@ -75,10 +54,6 @@ const (
MessageTransportOffsetContent = 16 MessageTransportOffsetContent = 16
) )
var packetSizeToMsgType map[int]uint32
var msgTypeToJunkSize map[uint32]int
/* Type is an 8-bit field, followed by 3 nul bytes, /* Type is an 8-bit field, followed by 3 nul bytes,
* by marshalling the messages in little-endian byteorder * by marshalling the messages in little-endian byteorder
* we can treat these as a 32-bit unsigned int (for now) * we can treat these as a 32-bit unsigned int (for now)
@ -120,11 +95,11 @@ type MessageCookieReply struct {
} }
type Handshake struct { type Handshake struct {
state handshakeState state int
mutex sync.RWMutex mutex sync.RWMutex
hash [blake2s.Size]byte // hash value hash [blake2s.Size]byte // hash value
chainKey [blake2s.Size]byte // chain key chainKey [blake2s.Size]byte // chain key
presharedKey NoisePresharedKey // psk presharedKey NoiseSymmetricKey // psk
localEphemeral NoisePrivateKey // ephemeral secret key localEphemeral NoisePrivateKey // ephemeral secret key
localIndex uint32 // used to clear hash-table localIndex uint32 // used to clear hash-table
remoteIndex uint32 // index for sending remoteIndex uint32 // index for sending
@ -142,11 +117,11 @@ var (
ZeroNonce [chacha20poly1305.NonceSize]byte ZeroNonce [chacha20poly1305.NonceSize]byte
) )
func mixKey(dst, c *[blake2s.Size]byte, data []byte) { func mixKey(dst *[blake2s.Size]byte, c *[blake2s.Size]byte, data []byte) {
KDF1(dst, c[:], data) KDF1(dst, c[:], data)
} }
func mixHash(dst, h *[blake2s.Size]byte, data []byte) { func mixHash(dst *[blake2s.Size]byte, h *[blake2s.Size]byte, data []byte) {
hash, _ := blake2s.New256(nil) hash, _ := blake2s.New256(nil)
hash.Write(h[:]) hash.Write(h[:])
hash.Write(data) hash.Write(data)
@ -160,7 +135,7 @@ func (h *Handshake) Clear() {
setZero(h.chainKey[:]) setZero(h.chainKey[:])
setZero(h.hash[:]) setZero(h.hash[:])
h.localIndex = 0 h.localIndex = 0
h.state = handshakeZeroed h.state = HandshakeZeroed
} }
func (h *Handshake) mixHash(data []byte) { func (h *Handshake) mixHash(data []byte) {
@ -179,6 +154,7 @@ func init() {
} }
func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) { func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) {
device.staticIdentity.RLock() device.staticIdentity.RLock()
defer device.staticIdentity.RUnlock() defer device.staticIdentity.RUnlock()
@ -186,7 +162,12 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
handshake.mutex.Lock() handshake.mutex.Lock()
defer handshake.mutex.Unlock() defer handshake.mutex.Unlock()
if isZero(handshake.precomputedStaticStatic[:]) {
return nil, errors.New("static shared secret is zero")
}
// create ephemeral key // create ephemeral key
var err error var err error
handshake.hash = InitialHash handshake.hash = InitialHash
handshake.chainKey = InitialChainKey handshake.chainKey = InitialChainKey
@ -195,58 +176,59 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
return nil, err return nil, err
} }
// assign index
device.indexTable.Delete(handshake.localIndex)
handshake.localIndex, err = device.indexTable.NewIndexForHandshake(peer, handshake)
if err != nil {
return nil, err
}
handshake.mixHash(handshake.remoteStatic[:]) handshake.mixHash(handshake.remoteStatic[:])
device.aSecMux.RLock()
msg := MessageInitiation{ msg := MessageInitiation{
Type: MessageInitiationType, Type: MessageInitiationType,
Ephemeral: handshake.localEphemeral.publicKey(), Ephemeral: handshake.localEphemeral.publicKey(),
Sender: handshake.localIndex,
} }
device.aSecMux.RUnlock()
handshake.mixKey(msg.Ephemeral[:]) handshake.mixKey(msg.Ephemeral[:])
handshake.mixHash(msg.Ephemeral[:]) handshake.mixHash(msg.Ephemeral[:])
// encrypt static key // encrypt static key
ss, err := handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
if err != nil { func() {
return nil, err var key [chacha20poly1305.KeySize]byte
} ss := handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
var key [chacha20poly1305.KeySize]byte KDF2(
KDF2( &handshake.chainKey,
&handshake.chainKey, &key,
&key, handshake.chainKey[:],
handshake.chainKey[:], ss[:],
ss[:], )
) aead, _ := chacha20poly1305.New(key[:])
aead, _ := chacha20poly1305.New(key[:]) aead.Seal(msg.Static[:0], ZeroNonce[:], device.staticIdentity.publicKey[:], handshake.hash[:])
aead.Seal(msg.Static[:0], ZeroNonce[:], device.staticIdentity.publicKey[:], handshake.hash[:]) }()
handshake.mixHash(msg.Static[:]) handshake.mixHash(msg.Static[:])
// encrypt timestamp // encrypt timestamp
if isZero(handshake.precomputedStaticStatic[:]) {
return nil, errInvalidPublicKey
}
KDF2(
&handshake.chainKey,
&key,
handshake.chainKey[:],
handshake.precomputedStaticStatic[:],
)
timestamp := tai64n.Now()
aead, _ = chacha20poly1305.New(key[:])
aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], handshake.hash[:])
// assign index timestamp := tai64n.Now()
device.indexTable.Delete(handshake.localIndex) func() {
msg.Sender, err = device.indexTable.NewIndexForHandshake(peer, handshake) var key [chacha20poly1305.KeySize]byte
if err != nil { KDF2(
return nil, err &handshake.chainKey,
} &key,
handshake.localIndex = msg.Sender handshake.chainKey[:],
handshake.precomputedStaticStatic[:],
)
aead, _ := chacha20poly1305.New(key[:])
aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], handshake.hash[:])
}()
handshake.mixHash(msg.Timestamp[:]) handshake.mixHash(msg.Timestamp[:])
handshake.state = handshakeInitiationCreated handshake.state = HandshakeInitiationCreated
return &msg, nil return &msg, nil
} }
@ -256,12 +238,9 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
chainKey [blake2s.Size]byte chainKey [blake2s.Size]byte
) )
device.aSecMux.RLock()
if msg.Type != MessageInitiationType { if msg.Type != MessageInitiationType {
device.aSecMux.RUnlock()
return nil return nil
} }
device.aSecMux.RUnlock()
device.staticIdentity.RLock() device.staticIdentity.RLock()
defer device.staticIdentity.RUnlock() defer device.staticIdentity.RUnlock()
@ -271,15 +250,16 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:]) mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:])
// decrypt static key // decrypt static key
var err error
var peerPK NoisePublicKey var peerPK NoisePublicKey
var key [chacha20poly1305.KeySize]byte func() {
ss, err := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral) var key [chacha20poly1305.KeySize]byte
if err != nil { ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
return nil KDF2(&chainKey, &key, chainKey[:], ss[:])
} aead, _ := chacha20poly1305.New(key[:])
KDF2(&chainKey, &key, chainKey[:], ss[:]) _, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:])
aead, _ := chacha20poly1305.New(key[:]) }()
_, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:])
if err != nil { if err != nil {
return nil return nil
} }
@ -288,29 +268,28 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
// lookup peer // lookup peer
peer := device.LookupPeer(peerPK) peer := device.LookupPeer(peerPK)
if peer == nil || !peer.isRunning.Load() { if peer == nil {
return nil return nil
} }
handshake := &peer.handshake handshake := &peer.handshake
if isZero(handshake.precomputedStaticStatic[:]) {
return nil
}
// verify identity // verify identity
var timestamp tai64n.Timestamp var timestamp tai64n.Timestamp
var key [chacha20poly1305.KeySize]byte
handshake.mutex.RLock() handshake.mutex.RLock()
if isZero(handshake.precomputedStaticStatic[:]) {
handshake.mutex.RUnlock()
return nil
}
KDF2( KDF2(
&chainKey, &chainKey,
&key, &key,
chainKey[:], chainKey[:],
handshake.precomputedStaticStatic[:], handshake.precomputedStaticStatic[:],
) )
aead, _ = chacha20poly1305.New(key[:]) aead, _ := chacha20poly1305.New(key[:])
_, err = aead.Open(timestamp[:0], ZeroNonce[:], msg.Timestamp[:], hash[:]) _, err = aead.Open(timestamp[:0], ZeroNonce[:], msg.Timestamp[:], hash[:])
if err != nil { if err != nil {
handshake.mutex.RUnlock() handshake.mutex.RUnlock()
@ -320,15 +299,11 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
// protect against replay & flood // protect against replay & flood
replay := !timestamp.After(handshake.lastTimestamp) var ok bool
flood := time.Since(handshake.lastInitiationConsumption) <= HandshakeInitationRate ok = timestamp.After(handshake.lastTimestamp)
ok = ok && time.Since(handshake.lastInitiationConsumption) > HandshakeInitationRate
handshake.mutex.RUnlock() handshake.mutex.RUnlock()
if replay { if !ok {
device.log.Verbosef("%v - ConsumeMessageInitiation: handshake replay @ %v", peer, timestamp)
return nil
}
if flood {
device.log.Verbosef("%v - ConsumeMessageInitiation: handshake flood", peer)
return nil return nil
} }
@ -340,14 +315,9 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
handshake.chainKey = chainKey handshake.chainKey = chainKey
handshake.remoteIndex = msg.Sender handshake.remoteIndex = msg.Sender
handshake.remoteEphemeral = msg.Ephemeral handshake.remoteEphemeral = msg.Ephemeral
if timestamp.After(handshake.lastTimestamp) { handshake.lastTimestamp = timestamp
handshake.lastTimestamp = timestamp handshake.lastInitiationConsumption = time.Now()
} handshake.state = HandshakeInitiationConsumed
now := time.Now()
if now.After(handshake.lastInitiationConsumption) {
handshake.lastInitiationConsumption = now
}
handshake.state = handshakeInitiationConsumed
handshake.mutex.Unlock() handshake.mutex.Unlock()
@ -362,7 +332,7 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
handshake.mutex.Lock() handshake.mutex.Lock()
defer handshake.mutex.Unlock() defer handshake.mutex.Unlock()
if handshake.state != handshakeInitiationConsumed { if handshake.state != HandshakeInitiationConsumed {
return nil, errors.New("handshake initiation must be consumed first") return nil, errors.New("handshake initiation must be consumed first")
} }
@ -376,9 +346,7 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
} }
var msg MessageResponse var msg MessageResponse
device.aSecMux.RLock()
msg.Type = MessageResponseType msg.Type = MessageResponseType
device.aSecMux.RUnlock()
msg.Sender = handshake.localIndex msg.Sender = handshake.localIndex
msg.Receiver = handshake.remoteIndex msg.Receiver = handshake.remoteIndex
@ -392,16 +360,12 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
handshake.mixHash(msg.Ephemeral[:]) handshake.mixHash(msg.Ephemeral[:])
handshake.mixKey(msg.Ephemeral[:]) handshake.mixKey(msg.Ephemeral[:])
ss, err := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral) func() {
if err != nil { ss := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral)
return nil, err handshake.mixKey(ss[:])
} ss = handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
handshake.mixKey(ss[:]) handshake.mixKey(ss[:])
ss, err = handshake.localEphemeral.sharedSecret(handshake.remoteStatic) }()
if err != nil {
return nil, err
}
handshake.mixKey(ss[:])
// add preshared key // add preshared key
@ -418,22 +382,21 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
handshake.mixHash(tau[:]) handshake.mixHash(tau[:])
aead, _ := chacha20poly1305.New(key[:]) func() {
aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:]) aead, _ := chacha20poly1305.New(key[:])
handshake.mixHash(msg.Empty[:]) aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:])
handshake.mixHash(msg.Empty[:])
}()
handshake.state = handshakeResponseCreated handshake.state = HandshakeResponseCreated
return &msg, nil return &msg, nil
} }
func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
device.aSecMux.RLock()
if msg.Type != MessageResponseType { if msg.Type != MessageResponseType {
device.aSecMux.RUnlock()
return nil return nil
} }
device.aSecMux.RUnlock()
// lookup handshake by receiver // lookup handshake by receiver
@ -449,12 +412,13 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
) )
ok := func() bool { ok := func() bool {
// lock handshake state // lock handshake state
handshake.mutex.RLock() handshake.mutex.RLock()
defer handshake.mutex.RUnlock() defer handshake.mutex.RUnlock()
if handshake.state != handshakeInitiationCreated { if handshake.state != HandshakeInitiationCreated {
return false return false
} }
@ -468,19 +432,17 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
mixHash(&hash, &handshake.hash, msg.Ephemeral[:]) mixHash(&hash, &handshake.hash, msg.Ephemeral[:])
mixKey(&chainKey, &handshake.chainKey, msg.Ephemeral[:]) mixKey(&chainKey, &handshake.chainKey, msg.Ephemeral[:])
ss, err := handshake.localEphemeral.sharedSecret(msg.Ephemeral) func() {
if err != nil { ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral)
return false mixKey(&chainKey, &chainKey, ss[:])
} setZero(ss[:])
mixKey(&chainKey, &chainKey, ss[:]) }()
setZero(ss[:])
ss, err = device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral) func() {
if err != nil { ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
return false mixKey(&chainKey, &chainKey, ss[:])
} setZero(ss[:])
mixKey(&chainKey, &chainKey, ss[:]) }()
setZero(ss[:])
// add preshared key (psk) // add preshared key (psk)
@ -498,7 +460,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
// authenticate transcript // authenticate transcript
aead, _ := chacha20poly1305.New(key[:]) aead, _ := chacha20poly1305.New(key[:])
_, err = aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:]) _, err := aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:])
if err != nil { if err != nil {
return false return false
} }
@ -517,7 +479,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
handshake.hash = hash handshake.hash = hash
handshake.chainKey = chainKey handshake.chainKey = chainKey
handshake.remoteIndex = msg.Sender handshake.remoteIndex = msg.Sender
handshake.state = handshakeResponseConsumed handshake.state = HandshakeResponseConsumed
handshake.mutex.Unlock() handshake.mutex.Unlock()
@ -542,7 +504,7 @@ func (peer *Peer) BeginSymmetricSession() error {
var sendKey [chacha20poly1305.KeySize]byte var sendKey [chacha20poly1305.KeySize]byte
var recvKey [chacha20poly1305.KeySize]byte var recvKey [chacha20poly1305.KeySize]byte
if handshake.state == handshakeResponseConsumed { if handshake.state == HandshakeResponseConsumed {
KDF2( KDF2(
&sendKey, &sendKey,
&recvKey, &recvKey,
@ -550,7 +512,7 @@ func (peer *Peer) BeginSymmetricSession() error {
nil, nil,
) )
isInitiator = true isInitiator = true
} else if handshake.state == handshakeResponseCreated { } else if handshake.state == HandshakeResponseCreated {
KDF2( KDF2(
&recvKey, &recvKey,
&sendKey, &sendKey,
@ -559,7 +521,7 @@ func (peer *Peer) BeginSymmetricSession() error {
) )
isInitiator = false isInitiator = false
} else { } else {
return fmt.Errorf("invalid state for keypair derivation: %v", handshake.state) return errors.New("invalid state for keypair derivation")
} }
// zero handshake // zero handshake
@ -567,7 +529,7 @@ func (peer *Peer) BeginSymmetricSession() error {
setZero(handshake.chainKey[:]) setZero(handshake.chainKey[:])
setZero(handshake.hash[:]) // Doesn't necessarily need to be zeroed. Could be used for something interesting down the line. setZero(handshake.hash[:]) // Doesn't necessarily need to be zeroed. Could be used for something interesting down the line.
setZero(handshake.localEphemeral[:]) setZero(handshake.localEphemeral[:])
peer.handshake.state = handshakeZeroed peer.handshake.state = HandshakeZeroed
// create AEAD instances // create AEAD instances
@ -579,7 +541,8 @@ func (peer *Peer) BeginSymmetricSession() error {
setZero(recvKey[:]) setZero(recvKey[:])
keypair.created = time.Now() keypair.created = time.Now()
keypair.replayFilter.Reset() keypair.sendNonce = 0
keypair.replayFilter.Init()
keypair.isInitiator = isInitiator keypair.isInitiator = isInitiator
keypair.localIndex = peer.handshake.localIndex keypair.localIndex = peer.handshake.localIndex
keypair.remoteIndex = peer.handshake.remoteIndex keypair.remoteIndex = peer.handshake.remoteIndex
@ -596,12 +559,12 @@ func (peer *Peer) BeginSymmetricSession() error {
defer keypairs.Unlock() defer keypairs.Unlock()
previous := keypairs.previous previous := keypairs.previous
next := keypairs.next.Load() next := keypairs.next
current := keypairs.current current := keypairs.current
if isInitiator { if isInitiator {
if next != nil { if next != nil {
keypairs.next.Store(nil) keypairs.next = nil
keypairs.previous = next keypairs.previous = next
device.DeleteKeypair(current) device.DeleteKeypair(current)
} else { } else {
@ -610,7 +573,7 @@ func (peer *Peer) BeginSymmetricSession() error {
device.DeleteKeypair(previous) device.DeleteKeypair(previous)
keypairs.current = keypair keypairs.current = keypair
} else { } else {
keypairs.next.Store(keypair) keypairs.next = keypair
device.DeleteKeypair(next) device.DeleteKeypair(next)
keypairs.previous = nil keypairs.previous = nil
device.DeleteKeypair(previous) device.DeleteKeypair(previous)
@ -621,19 +584,18 @@ func (peer *Peer) BeginSymmetricSession() error {
func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool { func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool {
keypairs := &peer.keypairs keypairs := &peer.keypairs
if keypairs.next != receivedKeypair {
if keypairs.next.Load() != receivedKeypair {
return false return false
} }
keypairs.Lock() keypairs.Lock()
defer keypairs.Unlock() defer keypairs.Unlock()
if keypairs.next.Load() != receivedKeypair { if keypairs.next != receivedKeypair {
return false return false
} }
old := keypairs.previous old := keypairs.previous
keypairs.previous = keypairs.current keypairs.previous = keypairs.current
peer.device.DeleteKeypair(old) peer.device.DeleteKeypair(old)
keypairs.current = keypairs.next.Load() keypairs.current = keypairs.next
keypairs.next.Store(nil) keypairs.next = nil
return true return true
} }
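
Despite the amount of churn above (error-returning sharedSecret, the aSec mutex around message types, atomic storage of the next keypair), the Noise IKpsk2 call sequence is unchanged. A hedged in-package sketch of that sequence, with dev1/dev2 and peer1/peer2 set up the same way as in TestNoiseHandshake further down:

// handshakeFlowSketch is illustrative only; it assumes package device's imports.
func handshakeFlowSketch(dev1, dev2 *Device, peer1, peer2 *Peer) error {
	msg1, err := dev1.CreateMessageInitiation(peer2) // initiator -> responder
	if err != nil {
		return err
	}
	if dev2.ConsumeMessageInitiation(msg1) == nil { // responder identifies the initiator
		return errors.New("initiation rejected")
	}
	msg2, err := dev2.CreateMessageResponse(peer1) // responder -> initiator
	if err != nil {
		return err
	}
	if dev1.ConsumeMessageResponse(msg2) == nil {
		return errors.New("response rejected")
	}
	// both ends can now derive transport keypairs
	if err := peer1.BeginSymmetricSession(); err != nil {
		return err
	}
	return peer2.BeginSymmetricSession()
}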


@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/ */
package device package device
@ -9,18 +9,19 @@ import (
"crypto/subtle" "crypto/subtle"
"encoding/hex" "encoding/hex"
"errors" "errors"
"golang.org/x/crypto/chacha20poly1305"
) )
const ( const (
NoisePublicKeySize = 32 NoisePublicKeySize = 32
NoisePrivateKeySize = 32 NoisePrivateKeySize = 32
NoisePresharedKeySize = 32
) )
type ( type (
NoisePublicKey [NoisePublicKeySize]byte NoisePublicKey [NoisePublicKeySize]byte
NoisePrivateKey [NoisePrivateKeySize]byte NoisePrivateKey [NoisePrivateKeySize]byte
NoisePresharedKey [NoisePresharedKeySize]byte NoiseSymmetricKey [chacha20poly1305.KeySize]byte
NoiseNonce uint64 // padded to 12-bytes NoiseNonce uint64 // padded to 12-bytes
) )
@ -51,19 +52,18 @@ func (key *NoisePrivateKey) FromHex(src string) (err error) {
return return
} }
func (key *NoisePrivateKey) FromMaybeZeroHex(src string) (err error) { func (key NoisePrivateKey) ToHex() string {
err = loadExactHex(key[:], src) return hex.EncodeToString(key[:])
if key.IsZero() {
return
}
key.clamp()
return
} }
func (key *NoisePublicKey) FromHex(src string) error { func (key *NoisePublicKey) FromHex(src string) error {
return loadExactHex(key[:], src) return loadExactHex(key[:], src)
} }
func (key NoisePublicKey) ToHex() string {
return hex.EncodeToString(key[:])
}
func (key NoisePublicKey) IsZero() bool { func (key NoisePublicKey) IsZero() bool {
var zero NoisePublicKey var zero NoisePublicKey
return key.Equals(zero) return key.Equals(zero)
@ -73,6 +73,10 @@ func (key NoisePublicKey) Equals(tar NoisePublicKey) bool {
return subtle.ConstantTimeCompare(key[:], tar[:]) == 1 return subtle.ConstantTimeCompare(key[:], tar[:]) == 1
} }
func (key *NoisePresharedKey) FromHex(src string) error { func (key *NoiseSymmetricKey) FromHex(src string) error {
return loadExactHex(key[:], src) return loadExactHex(key[:], src)
} }
func (key NoiseSymmetricKey) ToHex() string {
return hex.EncodeToString(key[:])
}
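
Keys are parsed from plain hex, 32 bytes each, and private keys are clamped after decoding. A standalone sketch of the decode-and-clamp step (the key string is a made-up example; the clamp masks are the standard Curve25519 ones, matching what NoisePrivateKey.clamp applies after FromHex):

package main

import (
	"encoding/hex"
	"fmt"
)

func main() {
	// Made-up example key: 64 hex characters for 32 bytes.
	src := "60c7053e8e397e34e54b54a01f9bb22c2ec459104ab8e2b4267288a66e62e470"
	raw, err := hex.DecodeString(src)
	if err != nil || len(raw) != 32 {
		panic("private keys are exactly 32 hex-encoded bytes")
	}
	var key [32]byte
	copy(key[:], raw)
	// Standard Curve25519 clamping.
	key[0] &= 248
	key[31] = (key[31] & 127) | 64
	fmt.Printf("%x\n", key)
}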


@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/ */
package device package device
@ -9,9 +9,6 @@ import (
"bytes" "bytes"
"encoding/binary" "encoding/binary"
"testing" "testing"
"github.com/amnezia-vpn/amneziawg-go/conn"
"github.com/amnezia-vpn/amneziawg-go/tun/tuntest"
) )
func TestCurveWrappers(t *testing.T) { func TestCurveWrappers(t *testing.T) {
@ -24,38 +21,14 @@ func TestCurveWrappers(t *testing.T) {
pk1 := sk1.publicKey() pk1 := sk1.publicKey()
pk2 := sk2.publicKey() pk2 := sk2.publicKey()
ss1, err1 := sk1.sharedSecret(pk2) ss1 := sk1.sharedSecret(pk2)
ss2, err2 := sk2.sharedSecret(pk1) ss2 := sk2.sharedSecret(pk1)
if ss1 != ss2 || err1 != nil || err2 != nil { if ss1 != ss2 {
t.Fatal("Failed to compute shared secet") t.Fatal("Failed to compute shared secet")
} }
} }
func randDevice(t *testing.T) *Device {
sk, err := newPrivateKey()
if err != nil {
t.Fatal(err)
}
tun := tuntest.NewChannelTUN()
logger := NewLogger(LogLevelError, "")
device := NewDevice(tun.TUN(), conn.NewDefaultBind(), logger)
device.SetPrivateKey(sk)
return device
}
func assertNil(t *testing.T, err error) {
if err != nil {
t.Fatal(err)
}
}
func assertEqual(t *testing.T, a, b []byte) {
if !bytes.Equal(a, b) {
t.Fatal(a, "!=", b)
}
}
func TestNoiseHandshake(t *testing.T) { func TestNoiseHandshake(t *testing.T) {
dev1 := randDevice(t) dev1 := randDevice(t)
dev2 := randDevice(t) dev2 := randDevice(t)
@ -63,16 +36,8 @@ func TestNoiseHandshake(t *testing.T) {
defer dev1.Close() defer dev1.Close()
defer dev2.Close() defer dev2.Close()
peer1, err := dev2.NewPeer(dev1.staticIdentity.privateKey.publicKey()) peer1, _ := dev2.NewPeer(dev1.staticIdentity.privateKey.publicKey())
if err != nil { peer2, _ := dev1.NewPeer(dev2.staticIdentity.privateKey.publicKey())
t.Fatal(err)
}
peer2, err := dev1.NewPeer(dev2.staticIdentity.privateKey.publicKey())
if err != nil {
t.Fatal(err)
}
peer1.Start()
peer2.Start()
assertEqual( assertEqual(
t, t,
@ -148,7 +113,7 @@ func TestNoiseHandshake(t *testing.T) {
t.Fatal("failed to derive keypair for peer 2", err) t.Fatal("failed to derive keypair for peer 2", err)
} }
key1 := peer1.keypairs.next.Load() key1 := peer1.keypairs.next
key2 := peer2.keypairs.current key2 := peer2.keypairs.current
// encrypting / decryption test // encrypting / decryption test


@ -1,35 +1,37 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/ */
package device package device
import ( import (
"container/list" "encoding/base64"
"errors" "errors"
"fmt"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
)
"github.com/amnezia-vpn/amneziawg-go/conn" const (
PeerRoutineNumber = 3
) )
type Peer struct { type Peer struct {
isRunning atomic.Bool isRunning AtomicBool
keypairs Keypairs sync.RWMutex // Mostly protects endpoint, but is generally taken whenever we modify peer
handshake Handshake keypairs Keypairs
device *Device handshake Handshake
stopping sync.WaitGroup // routines pending stop device *Device
txBytes atomic.Uint64 // bytes send to peer (endpoint) endpoint Endpoint
rxBytes atomic.Uint64 // bytes received from peer persistentKeepaliveInterval uint16
lastHandshakeNano atomic.Int64 // nano seconds since epoch
endpoint struct { // This must be 64-bit aligned, so make sure the above members come out to even alignment and pad accordingly
sync.Mutex stats struct {
val conn.Endpoint txBytes uint64 // bytes send to peer (endpoint)
clearSrcOnTx bool // signal to val.ClearSrc() prior to next packet transmission rxBytes uint64 // bytes received from peer
disableRoaming bool lastHandshakeNano int64 // nano seconds since epoch
} }
timers struct { timers struct {
@ -38,32 +40,40 @@ type Peer struct {
newHandshake *Timer newHandshake *Timer
zeroKeyMaterial *Timer zeroKeyMaterial *Timer
persistentKeepalive *Timer persistentKeepalive *Timer
handshakeAttempts atomic.Uint32 handshakeAttempts uint32
needAnotherKeepalive atomic.Bool needAnotherKeepalive AtomicBool
sentLastMinuteHandshake atomic.Bool sentLastMinuteHandshake AtomicBool
} }
state struct { signals struct {
sync.Mutex // protects against concurrent Start/Stop newKeypairArrived chan struct{}
flushNonceQueue chan struct{}
} }
queue struct { queue struct {
staged chan *QueueOutboundElementsContainer // staged packets before a handshake is available nonce chan *QueueOutboundElement // nonce / pre-handshake queue
outbound *autodrainingOutboundQueue // sequential ordering of udp transmission outbound chan *QueueOutboundElement // sequential ordering of work
inbound *autodrainingInboundQueue // sequential ordering of tun writing inbound chan *QueueInboundElement // sequential ordering of work
packetInNonceQueueIsAwaitingKey AtomicBool
} }
cookieGenerator CookieGenerator routines struct {
trieEntries list.List sync.Mutex // held when stopping / starting routines
persistentKeepaliveInterval atomic.Uint32 starting sync.WaitGroup // routines pending start
stopping sync.WaitGroup // routines pending stop
stop chan struct{} // size 0, stop all go routines in peer
}
cookieGenerator CookieGenerator
} }
func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
if device.isClosed() { if device.isClosed.Get() {
return nil, errors.New("device closed") return nil, errors.New("device closed")
} }
// lock resources // lock resources
device.staticIdentity.RLock() device.staticIdentity.RLock()
defer device.staticIdentity.RUnlock() defer device.staticIdentity.RUnlock()
@ -71,144 +81,136 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
defer device.peers.Unlock() defer device.peers.Unlock()
// check if over limit // check if over limit
if len(device.peers.keyMap) >= MaxPeers { if len(device.peers.keyMap) >= MaxPeers {
return nil, errors.New("too many peers") return nil, errors.New("too many peers")
} }
// create peer // create peer
peer := new(Peer) peer := new(Peer)
peer.Lock()
defer peer.Unlock()
peer.cookieGenerator.Init(pk) peer.cookieGenerator.Init(pk)
peer.device = device peer.device = device
peer.queue.outbound = newAutodrainingOutboundQueue(device) peer.isRunning.Set(false)
peer.queue.inbound = newAutodrainingInboundQueue(device)
peer.queue.staged = make(chan *QueueOutboundElementsContainer, QueueStagedSize)
// map public key // map public key
_, ok := device.peers.keyMap[pk] _, ok := device.peers.keyMap[pk]
if ok { if ok {
return nil, errors.New("adding existing peer") return nil, errors.New("adding existing peer")
} }
// pre-compute DH // pre-compute DH
handshake := &peer.handshake handshake := &peer.handshake
handshake.mutex.Lock() handshake.mutex.Lock()
handshake.precomputedStaticStatic, _ = device.staticIdentity.privateKey.sharedSecret(pk) handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(pk)
ssIsZero := isZero(handshake.precomputedStaticStatic[:])
handshake.remoteStatic = pk handshake.remoteStatic = pk
handshake.mutex.Unlock() handshake.mutex.Unlock()
// reset endpoint // reset endpoint
peer.endpoint.Lock()
peer.endpoint.val = nil
peer.endpoint.disableRoaming = false
peer.endpoint.clearSrcOnTx = false
peer.endpoint.Unlock()
// init timers peer.endpoint = nil
peer.timersInit()
// add // conditionally add
device.peers.keyMap[pk] = peer
if !ssIsZero {
device.peers.keyMap[pk] = peer
} else {
return nil, nil
}
// start peer
if peer.device.isUp.Get() {
peer.Start()
}
return peer, nil return peer, nil
} }
func (peer *Peer) SendBuffers(buffers [][]byte) error { func (peer *Peer) SendBuffer(buffer []byte) error {
peer.device.net.RLock() peer.device.net.RLock()
defer peer.device.net.RUnlock() defer peer.device.net.RUnlock()
if peer.device.isClosed() { if peer.device.net.bind == nil {
return nil return errors.New("no bind")
} }
peer.endpoint.Lock() peer.RLock()
endpoint := peer.endpoint.val defer peer.RUnlock()
if endpoint == nil {
peer.endpoint.Unlock() if peer.endpoint == nil {
return errors.New("no known endpoint for peer") return errors.New("no known endpoint for peer")
} }
if peer.endpoint.clearSrcOnTx {
endpoint.ClearSrc()
peer.endpoint.clearSrcOnTx = false
}
peer.endpoint.Unlock()
err := peer.device.net.bind.Send(buffers, endpoint) err := peer.device.net.bind.Send(buffer, peer.endpoint)
if err == nil { if err == nil {
var totalLen uint64 atomic.AddUint64(&peer.stats.txBytes, uint64(len(buffer)))
for _, b := range buffers {
totalLen += uint64(len(b))
}
peer.txBytes.Add(totalLen)
} }
return err return err
} }
func (peer *Peer) String() string { func (peer *Peer) String() string {
// The awful goo that follows is identical to: base64Key := base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:])
// abbreviatedKey := "invalid"
// base64Key := base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]) if len(base64Key) == 44 {
// abbreviatedKey := base64Key[0:4] + "…" + base64Key[39:43] abbreviatedKey = base64Key[0:4] + "…" + base64Key[39:43]
// return fmt.Sprintf("peer(%s)", abbreviatedKey)
//
// except that it is considerably more efficient.
src := peer.handshake.remoteStatic
b64 := func(input byte) byte {
return input + 'A' + byte(((25-int(input))>>8)&6) - byte(((51-int(input))>>8)&75) - byte(((61-int(input))>>8)&15) + byte(((62-int(input))>>8)&3)
} }
b := []byte("peer(____…____)") return fmt.Sprintf("peer(%s)", abbreviatedKey)
const first = len("peer(")
const second = len("peer(____…")
b[first+0] = b64((src[0] >> 2) & 63)
b[first+1] = b64(((src[0] << 4) | (src[1] >> 4)) & 63)
b[first+2] = b64(((src[1] << 2) | (src[2] >> 6)) & 63)
b[first+3] = b64(src[2] & 63)
b[second+0] = b64(src[29] & 63)
b[second+1] = b64((src[30] >> 2) & 63)
b[second+2] = b64(((src[30] << 4) | (src[31] >> 4)) & 63)
b[second+3] = b64((src[31] << 2) & 63)
return string(b)
} }
func (peer *Peer) Start() { func (peer *Peer) Start() {
// should never start a peer on a closed device // should never start a peer on a closed device
if peer.device.isClosed() {
if peer.device.isClosed.Get() {
return return
} }
// prevent simultaneous start/stop operations // prevent simultaneous start/stop operations
peer.state.Lock()
defer peer.state.Unlock()
if peer.isRunning.Load() { peer.routines.Lock()
defer peer.routines.Unlock()
if peer.isRunning.Get() {
return return
} }
device := peer.device device := peer.device
device.log.Verbosef("%v - Starting", peer) device.log.Debug.Println(peer, "- Starting...")
// reset routine state // reset routine state
peer.stopping.Wait()
peer.stopping.Add(2)
peer.handshake.mutex.Lock() peer.routines.starting.Wait()
peer.routines.stopping.Wait()
peer.routines.stop = make(chan struct{})
peer.routines.starting.Add(PeerRoutineNumber)
peer.routines.stopping.Add(PeerRoutineNumber)
// prepare queues
peer.queue.nonce = make(chan *QueueOutboundElement, QueueOutboundSize)
peer.queue.outbound = make(chan *QueueOutboundElement, QueueOutboundSize)
peer.queue.inbound = make(chan *QueueInboundElement, QueueInboundSize)
peer.timersInit()
peer.handshake.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second)) peer.handshake.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second))
peer.handshake.mutex.Unlock() peer.signals.newKeypairArrived = make(chan struct{}, 1)
peer.signals.flushNonceQueue = make(chan struct{}, 1)
peer.device.queue.encryption.wg.Add(1) // keep encryption queue open for our writes // wait for routines to start
peer.timersStart() go peer.RoutineNonce()
go peer.RoutineSequentialSender()
go peer.RoutineSequentialReceiver()
device.flushInboundQueue(peer.queue.inbound) peer.routines.starting.Wait()
device.flushOutboundQueue(peer.queue.outbound) peer.isRunning.Set(true)
// Use the device batch size, not the bind batch size, as the device size is
// the size of the batch pools.
batchSize := peer.device.BatchSize()
go peer.RoutineSequentialSender(batchSize)
go peer.RoutineSequentialReceiver(batchSize)
peer.isRunning.Store(true)
} }
func (peer *Peer) ZeroAndFlushAll() { func (peer *Peer) ZeroAndFlushAll() {
@ -220,10 +222,10 @@ func (peer *Peer) ZeroAndFlushAll() {
keypairs.Lock() keypairs.Lock()
device.DeleteKeypair(keypairs.previous) device.DeleteKeypair(keypairs.previous)
device.DeleteKeypair(keypairs.current) device.DeleteKeypair(keypairs.current)
device.DeleteKeypair(keypairs.next.Load()) device.DeleteKeypair(keypairs.next)
keypairs.previous = nil keypairs.previous = nil
keypairs.current = nil keypairs.current = nil
keypairs.next.Store(nil) keypairs.next = nil
keypairs.Unlock() keypairs.Unlock()
// clear handshake state // clear handshake state
@ -234,7 +236,7 @@ func (peer *Peer) ZeroAndFlushAll() {
handshake.Clear() handshake.Clear()
handshake.mutex.Unlock() handshake.mutex.Unlock()
peer.FlushStagedPackets() peer.FlushNonceQueue()
} }
func (peer *Peer) ExpireCurrentKeypairs() { func (peer *Peer) ExpireCurrentKeypairs() {
@ -242,55 +244,58 @@ func (peer *Peer) ExpireCurrentKeypairs() {
handshake.mutex.Lock() handshake.mutex.Lock()
peer.device.indexTable.Delete(handshake.localIndex) peer.device.indexTable.Delete(handshake.localIndex)
handshake.Clear() handshake.Clear()
peer.handshake.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second))
handshake.mutex.Unlock() handshake.mutex.Unlock()
peer.handshake.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second))
keypairs := &peer.keypairs keypairs := &peer.keypairs
keypairs.Lock() keypairs.Lock()
if keypairs.current != nil { if keypairs.current != nil {
keypairs.current.sendNonce.Store(RejectAfterMessages) keypairs.current.sendNonce = RejectAfterMessages
} }
if next := keypairs.next.Load(); next != nil { if keypairs.next != nil {
next.sendNonce.Store(RejectAfterMessages) keypairs.next.sendNonce = RejectAfterMessages
} }
keypairs.Unlock() keypairs.Unlock()
} }
func (peer *Peer) Stop() { func (peer *Peer) Stop() {
peer.state.Lock()
defer peer.state.Unlock() // prevent simultaneous start/stop operations
if !peer.isRunning.Swap(false) { if !peer.isRunning.Swap(false) {
return return
} }
peer.device.log.Verbosef("%v - Stopping", peer) peer.routines.starting.Wait()
peer.routines.Lock()
defer peer.routines.Unlock()
peer.device.log.Debug.Println(peer, "- Stopping...")
peer.timersStop() peer.timersStop()
// Signal that RoutineSequentialSender and RoutineSequentialReceiver should exit.
peer.queue.inbound.c <- nil // stop & wait for ongoing peer routines
peer.queue.outbound.c <- nil
peer.stopping.Wait() close(peer.routines.stop)
peer.device.queue.encryption.wg.Done() // no more writes to encryption queue from us peer.routines.stopping.Wait()
// close queues
close(peer.queue.nonce)
close(peer.queue.outbound)
close(peer.queue.inbound)
peer.ZeroAndFlushAll() peer.ZeroAndFlushAll()
} }
func (peer *Peer) SetEndpointFromPacket(endpoint conn.Endpoint) { var RoamingDisabled bool
peer.endpoint.Lock()
defer peer.endpoint.Unlock()
if peer.endpoint.disableRoaming {
return
}
peer.endpoint.clearSrcOnTx = false
peer.endpoint.val = endpoint
}
func (peer *Peer) markEndpointSrcForClearing() { func (peer *Peer) SetEndpointFromPacket(endpoint Endpoint) {
peer.endpoint.Lock() if RoamingDisabled {
defer peer.endpoint.Unlock()
if peer.endpoint.val == nil {
return return
} }
peer.endpoint.clearSrcOnTx = true peer.Lock()
peer.endpoint = endpoint
peer.Unlock()
} }
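
The master Peer guards Start and Stop with a small state mutex plus an atomic running flag, where the older code juggled starting/stopping WaitGroups and a stop channel. The following standalone sketch blends both ideas to show the guard pattern on its own; it is illustrative, not the peer implementation:

package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

type worker struct {
	state     sync.Mutex    // serializes Start/Stop, like peer.state above
	isRunning atomic.Bool   // checked lock-free on hot paths
	stop      chan struct{} // closed to ask the routine to exit
	stopping  sync.WaitGroup
}

func (w *worker) Start() {
	w.state.Lock()
	defer w.state.Unlock()
	if w.isRunning.Load() {
		return
	}
	w.stop = make(chan struct{})
	w.stopping.Add(1)
	go func() {
		defer w.stopping.Done()
		<-w.stop // placeholder routine body: run until asked to stop
	}()
	w.isRunning.Store(true)
}

func (w *worker) Stop() {
	w.state.Lock()
	defer w.state.Unlock()
	if !w.isRunning.Swap(false) {
		return
	}
	close(w.stop)
	w.stopping.Wait()
}

func main() {
	var w worker
	w.Start()
	w.Stop()
	fmt.Println("stopped cleanly")
}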


@ -1,120 +1,89 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/ */
package device package device
import ( import "sync"
"sync"
"sync/atomic"
)
type WaitPool struct {
pool sync.Pool
cond sync.Cond
lock sync.Mutex
count atomic.Uint32
max uint32
}
func NewWaitPool(max uint32, new func() any) *WaitPool {
p := &WaitPool{pool: sync.Pool{New: new}, max: max}
p.cond = sync.Cond{L: &p.lock}
return p
}
func (p *WaitPool) Get() any {
if p.max != 0 {
p.lock.Lock()
for p.count.Load() >= p.max {
p.cond.Wait()
}
p.count.Add(1)
p.lock.Unlock()
}
return p.pool.Get()
}
func (p *WaitPool) Put(x any) {
p.pool.Put(x)
if p.max == 0 {
return
}
p.count.Add(^uint32(0))
p.cond.Signal()
}
func (device *Device) PopulatePools() { func (device *Device) PopulatePools() {
device.pool.inboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any { if PreallocatedBuffersPerPool == 0 {
s := make([]*QueueInboundElement, 0, device.BatchSize()) device.pool.messageBufferPool = &sync.Pool{
return &QueueInboundElementsContainer{elems: s} New: func() interface{} {
}) return new([MaxMessageSize]byte)
device.pool.outboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any { },
s := make([]*QueueOutboundElement, 0, device.BatchSize()) }
return &QueueOutboundElementsContainer{elems: s} device.pool.inboundElementPool = &sync.Pool{
}) New: func() interface{} {
device.pool.messageBuffers = NewWaitPool(PreallocatedBuffersPerPool, func() any { return new(QueueInboundElement)
return new([MaxMessageSize]byte) },
}) }
device.pool.inboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() any { device.pool.outboundElementPool = &sync.Pool{
return new(QueueInboundElement) New: func() interface{} {
}) return new(QueueOutboundElement)
device.pool.outboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() any { },
return new(QueueOutboundElement) }
}) } else {
} device.pool.messageBufferReuseChan = make(chan *[MaxMessageSize]byte, PreallocatedBuffersPerPool)
for i := 0; i < PreallocatedBuffersPerPool; i += 1 {
func (device *Device) GetInboundElementsContainer() *QueueInboundElementsContainer { device.pool.messageBufferReuseChan <- new([MaxMessageSize]byte)
c := device.pool.inboundElementsContainer.Get().(*QueueInboundElementsContainer) }
c.Mutex = sync.Mutex{} device.pool.inboundElementReuseChan = make(chan *QueueInboundElement, PreallocatedBuffersPerPool)
return c for i := 0; i < PreallocatedBuffersPerPool; i += 1 {
} device.pool.inboundElementReuseChan <- new(QueueInboundElement)
}
func (device *Device) PutInboundElementsContainer(c *QueueInboundElementsContainer) { device.pool.outboundElementReuseChan = make(chan *QueueOutboundElement, PreallocatedBuffersPerPool)
for i := range c.elems { for i := 0; i < PreallocatedBuffersPerPool; i += 1 {
c.elems[i] = nil device.pool.outboundElementReuseChan <- new(QueueOutboundElement)
}
} }
c.elems = c.elems[:0]
device.pool.inboundElementsContainer.Put(c)
}
func (device *Device) GetOutboundElementsContainer() *QueueOutboundElementsContainer {
c := device.pool.outboundElementsContainer.Get().(*QueueOutboundElementsContainer)
c.Mutex = sync.Mutex{}
return c
}
func (device *Device) PutOutboundElementsContainer(c *QueueOutboundElementsContainer) {
for i := range c.elems {
c.elems[i] = nil
}
c.elems = c.elems[:0]
device.pool.outboundElementsContainer.Put(c)
} }
func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte { func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte {
return device.pool.messageBuffers.Get().(*[MaxMessageSize]byte) if PreallocatedBuffersPerPool == 0 {
return device.pool.messageBufferPool.Get().(*[MaxMessageSize]byte)
} else {
return <-device.pool.messageBufferReuseChan
}
} }
func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) { func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) {
device.pool.messageBuffers.Put(msg) if PreallocatedBuffersPerPool == 0 {
device.pool.messageBufferPool.Put(msg)
} else {
device.pool.messageBufferReuseChan <- msg
}
} }
func (device *Device) GetInboundElement() *QueueInboundElement { func (device *Device) GetInboundElement() *QueueInboundElement {
return device.pool.inboundElements.Get().(*QueueInboundElement) if PreallocatedBuffersPerPool == 0 {
return device.pool.inboundElementPool.Get().(*QueueInboundElement)
} else {
return <-device.pool.inboundElementReuseChan
}
} }
func (device *Device) PutInboundElement(elem *QueueInboundElement) { func (device *Device) PutInboundElement(msg *QueueInboundElement) {
elem.clearPointers() if PreallocatedBuffersPerPool == 0 {
device.pool.inboundElements.Put(elem) device.pool.inboundElementPool.Put(msg)
} else {
device.pool.inboundElementReuseChan <- msg
}
} }
func (device *Device) GetOutboundElement() *QueueOutboundElement { func (device *Device) GetOutboundElement() *QueueOutboundElement {
return device.pool.outboundElements.Get().(*QueueOutboundElement) if PreallocatedBuffersPerPool == 0 {
return device.pool.outboundElementPool.Get().(*QueueOutboundElement)
} else {
return <-device.pool.outboundElementReuseChan
}
} }
func (device *Device) PutOutboundElement(elem *QueueOutboundElement) { func (device *Device) PutOutboundElement(msg *QueueOutboundElement) {
elem.clearPointers() if PreallocatedBuffersPerPool == 0 {
device.pool.outboundElements.Put(elem) device.pool.outboundElementPool.Put(msg)
} else {
device.pool.outboundElementReuseChan <- msg
}
} }

@ -1,139 +0,0 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"math/rand"
"runtime"
"sync"
"sync/atomic"
"testing"
"time"
)
func TestWaitPool(t *testing.T) {
t.Skip("Currently disabled")
var wg sync.WaitGroup
var trials atomic.Int32
startTrials := int32(100000)
if raceEnabled {
// This test can be very slow with -race.
startTrials /= 10
}
trials.Store(startTrials)
workers := runtime.NumCPU() + 2
if workers-4 <= 0 {
t.Skip("Not enough cores")
}
p := NewWaitPool(uint32(workers-4), func() any { return make([]byte, 16) })
wg.Add(workers)
var max atomic.Uint32
updateMax := func() {
count := p.count.Load()
if count > p.max {
t.Errorf("count (%d) > max (%d)", count, p.max)
}
for {
old := max.Load()
if count <= old {
break
}
if max.CompareAndSwap(old, count) {
break
}
}
}
for i := 0; i < workers; i++ {
go func() {
defer wg.Done()
for trials.Add(-1) > 0 {
updateMax()
x := p.Get()
updateMax()
time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond)
updateMax()
p.Put(x)
updateMax()
}
}()
}
wg.Wait()
if max.Load() != p.max {
t.Errorf("Actual maximum count (%d) != ideal maximum count (%d)", max, p.max)
}
}
func BenchmarkWaitPool(b *testing.B) {
var wg sync.WaitGroup
var trials atomic.Int32
trials.Store(int32(b.N))
workers := runtime.NumCPU() + 2
if workers-4 <= 0 {
b.Skip("Not enough cores")
}
p := NewWaitPool(uint32(workers-4), func() any { return make([]byte, 16) })
wg.Add(workers)
b.ResetTimer()
for i := 0; i < workers; i++ {
go func() {
defer wg.Done()
for trials.Add(-1) > 0 {
x := p.Get()
time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond)
p.Put(x)
}
}()
}
wg.Wait()
}
func BenchmarkWaitPoolEmpty(b *testing.B) {
var wg sync.WaitGroup
var trials atomic.Int32
trials.Store(int32(b.N))
workers := runtime.NumCPU() + 2
if workers-4 <= 0 {
b.Skip("Not enough cores")
}
p := NewWaitPool(0, func() any { return make([]byte, 16) })
wg.Add(workers)
b.ResetTimer()
for i := 0; i < workers; i++ {
go func() {
defer wg.Done()
for trials.Add(-1) > 0 {
x := p.Get()
time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond)
p.Put(x)
}
}()
}
wg.Wait()
}
func BenchmarkSyncPool(b *testing.B) {
var wg sync.WaitGroup
var trials atomic.Int32
trials.Store(int32(b.N))
workers := runtime.NumCPU() + 2
if workers-4 <= 0 {
b.Skip("Not enough cores")
}
p := sync.Pool{New: func() any { return make([]byte, 16) }}
wg.Add(workers)
b.ResetTimer()
for i := 0; i < workers; i++ {
go func() {
defer wg.Done()
for trials.Add(-1) > 0 {
x := p.Get()
time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond)
p.Put(x)
}
}()
}
wg.Wait()
}

@ -1,19 +1,16 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/ */
package device package device
import "github.com/amnezia-vpn/amneziawg-go/conn"
/* Reduce memory consumption for Android */ /* Reduce memory consumption for Android */
const ( const (
QueueStagedSize = conn.IdealBatchSize
QueueOutboundSize = 1024 QueueOutboundSize = 1024
QueueInboundSize = 1024 QueueInboundSize = 1024
QueueHandshakeSize = 1024 QueueHandshakeSize = 1024
MaxSegmentSize = (1 << 16) - 1 // largest possible UDP datagram MaxSegmentSize = 2200
PreallocatedBuffersPerPool = 4096 PreallocatedBuffersPerPool = 4096
) )

@ -1,16 +1,13 @@
//go:build !android && !ios && !windows // +build !android,!ios
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/ */
package device package device
import "github.com/amnezia-vpn/amneziawg-go/conn"
const ( const (
QueueStagedSize = conn.IdealBatchSize
QueueOutboundSize = 1024 QueueOutboundSize = 1024
QueueInboundSize = 1024 QueueInboundSize = 1024
QueueHandshakeSize = 1024 QueueHandshakeSize = 1024

@ -1,21 +1,18 @@
//go:build ios // +build ios
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/ */
package device package device
// Fit within memory limits for iOS's Network Extension API, which has stricter requirements. /* Fit within memory limits for iOS's Network Extension API, which has stricter requirements */
// These are vars instead of consts, because heavier network extensions might want to reduce
// them further.
var (
QueueStagedSize = 128
QueueOutboundSize = 1024
QueueInboundSize = 1024
QueueHandshakeSize = 1024
PreallocatedBuffersPerPool uint32 = 1024
)
const MaxSegmentSize = 1700 const (
QueueOutboundSize = 1024
QueueInboundSize = 1024
QueueHandshakeSize = 1024
MaxSegmentSize = 1700
PreallocatedBuffersPerPool = 1024
)

@ -1,15 +0,0 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
const (
QueueStagedSize = 128
QueueOutboundSize = 1024
QueueInboundSize = 1024
QueueHandshakeSize = 1024
MaxSegmentSize = 2048 - 32 // largest possible UDP datagram
PreallocatedBuffersPerPool = 0 // Disable and allow for infinite memory growth
)

@ -1,10 +0,0 @@
//go:build !race
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
const raceEnabled = false

@ -1,10 +0,0 @@
//go:build race
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
const raceEnabled = true

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/ */
package device package device
@ -8,12 +8,12 @@ package device
import ( import (
"bytes" "bytes"
"encoding/binary" "encoding/binary"
"errors"
"net" "net"
"strconv"
"sync" "sync"
"sync/atomic"
"time" "time"
"github.com/amnezia-vpn/amneziawg-go/conn"
"golang.org/x/crypto/chacha20poly1305" "golang.org/x/crypto/chacha20poly1305"
"golang.org/x/net/ipv4" "golang.org/x/net/ipv4"
"golang.org/x/net/ipv6" "golang.org/x/net/ipv6"
@ -22,32 +22,52 @@ import (
type QueueHandshakeElement struct { type QueueHandshakeElement struct {
msgType uint32 msgType uint32
packet []byte packet []byte
endpoint conn.Endpoint endpoint Endpoint
buffer *[MaxMessageSize]byte buffer *[MaxMessageSize]byte
} }
type QueueInboundElement struct { type QueueInboundElement struct {
dropped int32
sync.Mutex
buffer *[MaxMessageSize]byte buffer *[MaxMessageSize]byte
packet []byte packet []byte
counter uint64 counter uint64
keypair *Keypair keypair *Keypair
endpoint conn.Endpoint endpoint Endpoint
} }
type QueueInboundElementsContainer struct { func (elem *QueueInboundElement) Drop() {
sync.Mutex atomic.StoreInt32(&elem.dropped, AtomicTrue)
elems []*QueueInboundElement
} }
// clearPointers clears elem fields that contain pointers. func (elem *QueueInboundElement) IsDropped() bool {
// This makes the garbage collector's life easier and return atomic.LoadInt32(&elem.dropped) == AtomicTrue
// avoids accidentally keeping other objects around unnecessarily. }
// It also reduces the possible collateral damage from use-after-free bugs.
func (elem *QueueInboundElement) clearPointers() { func (device *Device) addToInboundAndDecryptionQueues(inboundQueue chan *QueueInboundElement, decryptionQueue chan *QueueInboundElement, element *QueueInboundElement) bool {
elem.buffer = nil select {
elem.packet = nil case inboundQueue <- element:
elem.keypair = nil select {
elem.endpoint = nil case decryptionQueue <- element:
return true
default:
element.Drop()
element.Unlock()
return false
}
default:
device.PutInboundElement(element)
return false
}
}
func (device *Device) addToHandshakeQueue(queue chan QueueHandshakeElement, element QueueHandshakeElement) bool {
select {
case queue <- element:
return true
default:
return false
}
} }
/* Called when a new authenticated message has been received /* Called when a new authenticated message has been received
@ -55,12 +75,12 @@ func (elem *QueueInboundElement) clearPointers() {
* NOTE: Not thread safe, but called by sequential receiver! * NOTE: Not thread safe, but called by sequential receiver!
*/ */
func (peer *Peer) keepKeyFreshReceiving() { func (peer *Peer) keepKeyFreshReceiving() {
if peer.timers.sentLastMinuteHandshake.Load() { if peer.timers.sentLastMinuteHandshake.Get() {
return return
} }
keypair := peer.keypairs.Current() keypair := peer.keypairs.Current()
if keypair != nil && keypair.isInitiator && time.Since(keypair.created) > (RejectAfterTime-KeepaliveTimeout-RekeyTimeout) { if keypair != nil && keypair.isInitiator && time.Since(keypair.created) > (RejectAfterTime-KeepaliveTimeout-RekeyTimeout) {
peer.timers.sentLastMinuteHandshake.Store(true) peer.timers.sentLastMinuteHandshake.Set(true)
peer.SendHandshakeInitiation(false) peer.SendHandshakeInitiation(false)
} }
} }
@ -70,215 +90,188 @@ func (peer *Peer) keepKeyFreshReceiving() {
* Every time the bind is updated a new routine is started for * Every time the bind is updated a new routine is started for
* IPv4 and IPv6 (separately) * IPv4 and IPv6 (separately)
*/ */
func (device *Device) RoutineReceiveIncoming( func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) {
maxBatchSize int,
recv conn.ReceiveFunc, logDebug := device.log.Debug
) {
recvName := recv.PrettyName()
defer func() { defer func() {
device.log.Verbosef("Routine: receive incoming %s - stopped", recvName) logDebug.Println("Routine: receive incoming IPv" + strconv.Itoa(IP) + " - stopped")
device.queue.decryption.wg.Done()
device.queue.handshake.wg.Done()
device.net.stopping.Done() device.net.stopping.Done()
}() }()
device.log.Verbosef("Routine: receive incoming %s - started", recvName) logDebug.Println("Routine: receive incoming IPv" + strconv.Itoa(IP) + " - started")
device.net.starting.Done()
// receive datagrams until conn is closed // receive datagrams until conn is closed
buffer := device.GetMessageBuffer()
var ( var (
bufsArrs = make([]*[MaxMessageSize]byte, maxBatchSize) err error
bufs = make([][]byte, maxBatchSize) size int
err error endpoint Endpoint
sizes = make([]int, maxBatchSize)
count int
endpoints = make([]conn.Endpoint, maxBatchSize)
deathSpiral int
elemsByPeer = make(map[*Peer]*QueueInboundElementsContainer, maxBatchSize)
) )
for i := range bufsArrs {
bufsArrs[i] = device.GetMessageBuffer()
bufs[i] = bufsArrs[i][:]
}
defer func() {
for i := 0; i < maxBatchSize; i++ {
if bufsArrs[i] != nil {
device.PutMessageBuffer(bufsArrs[i])
}
}
}()
for { for {
count, err = recv(bufs, sizes, endpoints)
// read next datagram
switch IP {
case ipv4.Version:
size, endpoint, err = bind.ReceiveIPv4(buffer[:])
case ipv6.Version:
size, endpoint, err = bind.ReceiveIPv6(buffer[:])
default:
panic("invalid IP version")
}
if err != nil { if err != nil {
if errors.Is(err, net.ErrClosed) { device.PutMessageBuffer(buffer)
return
}
device.log.Verbosef("Failed to receive %s packet: %v", recvName, err)
if neterr, ok := err.(net.Error); ok && !neterr.Temporary() {
return
}
if deathSpiral < 10 {
deathSpiral++
time.Sleep(time.Second / 3)
continue
}
return return
} }
deathSpiral = 0
device.aSecMux.RLock() if size < MinMessageSize {
// handle each packet in the batch continue
for i, size := range sizes[:count] {
if size < MinMessageSize {
continue
}
// check size of packet
packet := bufsArrs[i][:size]
var msgType uint32
if device.isAdvancedSecurityOn() {
if assumedMsgType, ok := packetSizeToMsgType[size]; ok {
junkSize := msgTypeToJunkSize[assumedMsgType]
// a transport packet's length can coincide with a junk-prefixed header type's length;
// check the type word behind the junk offset to confirm which msgType this really is
msgType = binary.LittleEndian.Uint32(packet[junkSize : junkSize+4])
if msgType == assumedMsgType {
packet = packet[junkSize:]
} else {
device.log.Verbosef("Transport packet lined up with another msg type")
msgType = binary.LittleEndian.Uint32(packet[:4])
}
} else {
msgType = binary.LittleEndian.Uint32(packet[:4])
if msgType != MessageTransportType {
device.log.Verbosef("ASec: Received message with unknown type")
continue
}
}
} else {
msgType = binary.LittleEndian.Uint32(packet[:4])
}
switch msgType {
// check if transport
case MessageTransportType:
// check size
if len(packet) < MessageTransportSize {
continue
}
// lookup key pair
receiver := binary.LittleEndian.Uint32(
packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter],
)
value := device.indexTable.Lookup(receiver)
keypair := value.keypair
if keypair == nil {
continue
}
// check keypair expiry
if keypair.created.Add(RejectAfterTime).Before(time.Now()) {
continue
}
// create work element
peer := value.peer
elem := device.GetInboundElement()
elem.packet = packet
elem.buffer = bufsArrs[i]
elem.keypair = keypair
elem.endpoint = endpoints[i]
elem.counter = 0
elemsForPeer, ok := elemsByPeer[peer]
if !ok {
elemsForPeer = device.GetInboundElementsContainer()
elemsForPeer.Lock()
elemsByPeer[peer] = elemsForPeer
}
elemsForPeer.elems = append(elemsForPeer.elems, elem)
bufsArrs[i] = device.GetMessageBuffer()
bufs[i] = bufsArrs[i][:]
continue
// otherwise it is a fixed size & handshake related packet
case MessageInitiationType:
if len(packet) != MessageInitiationSize {
continue
}
case MessageResponseType:
if len(packet) != MessageResponseSize {
continue
}
case MessageCookieReplyType:
if len(packet) != MessageCookieReplySize {
continue
}
default:
device.log.Verbosef("Received message with unknown type")
continue
}
select {
case device.queue.handshake.c <- QueueHandshakeElement{
msgType: msgType,
buffer: bufsArrs[i],
packet: packet,
endpoint: endpoints[i],
}:
bufsArrs[i] = device.GetMessageBuffer()
bufs[i] = bufsArrs[i][:]
default:
}
} }
device.aSecMux.RUnlock()
for peer, elemsContainer := range elemsByPeer { // check size of packet
if peer.isRunning.Load() {
peer.queue.inbound.c <- elemsContainer packet := buffer[:size]
device.queue.decryption.c <- elemsContainer msgType := binary.LittleEndian.Uint32(packet[:4])
} else {
for _, elem := range elemsContainer.elems { var okay bool
device.PutMessageBuffer(elem.buffer)
device.PutInboundElement(elem) switch msgType {
}
device.PutInboundElementsContainer(elemsContainer) // check if transport
case MessageTransportType:
// check size
if len(packet) < MessageTransportSize {
continue
}
// lookup key pair
receiver := binary.LittleEndian.Uint32(
packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter],
)
value := device.indexTable.Lookup(receiver)
keypair := value.keypair
if keypair == nil {
continue
}
// check keypair expiry
if keypair.created.Add(RejectAfterTime).Before(time.Now()) {
continue
}
// create work element
peer := value.peer
elem := device.GetInboundElement()
elem.packet = packet
elem.buffer = buffer
elem.keypair = keypair
elem.dropped = AtomicFalse
elem.endpoint = endpoint
elem.counter = 0
elem.Mutex = sync.Mutex{}
elem.Lock()
// add to decryption queues
if peer.isRunning.Get() {
if device.addToInboundAndDecryptionQueues(peer.queue.inbound, device.queue.decryption, elem) {
buffer = device.GetMessageBuffer()
}
}
continue
// otherwise it is a fixed size & handshake related packet
case MessageInitiationType:
okay = len(packet) == MessageInitiationSize
case MessageResponseType:
okay = len(packet) == MessageResponseSize
case MessageCookieReplyType:
okay = len(packet) == MessageCookieReplySize
default:
logDebug.Println("Received message with unknown type")
}
if okay {
if (device.addToHandshakeQueue(
device.queue.handshake,
QueueHandshakeElement{
msgType: msgType,
buffer: buffer,
packet: packet,
endpoint: endpoint,
},
)) {
buffer = device.GetMessageBuffer()
} }
delete(elemsByPeer, peer)
} }
} }
} }
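
In the advanced-security path above, the receiver first guesses the message type from the datagram's length (each junk-prefixed header type has a fixed on-wire size), skips that type's junk prefix, and only trusts the guess if the 4-byte little-endian type word found there matches; otherwise it falls back to reading the type at offset 0. Below is a self-contained sketch of that idea only; the table entries (a 163-byte size and a 15-byte junk prefix) are illustrative assumptions, not the real values, which come from the peer's aSec configuration.

package main

import (
	"encoding/binary"
	"fmt"
)

// Illustrative stand-ins for the device package's internal tables; the real
// packetSizeToMsgType / msgTypeToJunkSize maps are derived from the
// advanced-security config and are not reproduced here.
var (
	packetSizeToMsgType = map[int]uint32{163: 1} // 1 plays the role of MessageInitiationType
	msgTypeToJunkSize   = map[uint32]int{1: 15}
)

// stripJunk mirrors the idea above: if the datagram's size matches a known
// junk-prefixed message size, skip the junk and confirm the 4-byte
// little-endian type field before trusting the assumption.
func stripJunk(packet []byte) ([]byte, uint32) {
	if assumed, ok := packetSizeToMsgType[len(packet)]; ok {
		junk := msgTypeToJunkSize[assumed]
		if t := binary.LittleEndian.Uint32(packet[junk : junk+4]); t == assumed {
			return packet[junk:], t
		}
	}
	return packet, binary.LittleEndian.Uint32(packet[:4])
}

func main() {
	pkt := make([]byte, 163)
	binary.LittleEndian.PutUint32(pkt[15:], 1) // type word placed after 15 junk bytes
	body, msgType := stripJunk(pkt)
	fmt.Println(len(body), msgType) // 148 1
}
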
func (device *Device) RoutineDecryption(id int) { func (device *Device) RoutineDecryption() {
var nonce [chacha20poly1305.NonceSize]byte var nonce [chacha20poly1305.NonceSize]byte
defer device.log.Verbosef("Routine: decryption worker %d - stopped", id) logDebug := device.log.Debug
device.log.Verbosef("Routine: decryption worker %d - started", id) defer func() {
logDebug.Println("Routine: decryption worker - stopped")
device.state.stopping.Done()
}()
logDebug.Println("Routine: decryption worker - started")
device.state.starting.Done()
for {
select {
case <-device.signals.stop:
return
case elem, ok := <-device.queue.decryption:
if !ok {
return
}
// check if dropped
if elem.IsDropped() {
continue
}
for elemsContainer := range device.queue.decryption.c {
for _, elem := range elemsContainer.elems {
// split message into fields // split message into fields
counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent] counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent]
content := elem.packet[MessageTransportOffsetContent:] content := elem.packet[MessageTransportOffsetContent:]
// expand nonce
nonce[0x4] = counter[0x0]
nonce[0x5] = counter[0x1]
nonce[0x6] = counter[0x2]
nonce[0x7] = counter[0x3]
nonce[0x8] = counter[0x4]
nonce[0x9] = counter[0x5]
nonce[0xa] = counter[0x6]
nonce[0xb] = counter[0x7]
// decrypt and release to consumer // decrypt and release to consumer
var err error var err error
elem.counter = binary.LittleEndian.Uint64(counter) elem.counter = binary.LittleEndian.Uint64(counter)
// copy counter to nonce
binary.LittleEndian.PutUint64(nonce[0x4:0xc], elem.counter)
elem.packet, err = elem.keypair.receive.Open( elem.packet, err = elem.keypair.receive.Open(
content[:0], content[:0],
nonce[:], nonce[:],
@ -286,25 +279,51 @@ func (device *Device) RoutineDecryption(id int) {
nil, nil,
) )
if err != nil { if err != nil {
elem.packet = nil elem.Drop()
device.PutMessageBuffer(elem.buffer)
} }
elem.Unlock()
} }
elemsContainer.Unlock()
} }
} }
/* Handles incoming packets related to handshake /* Handles incoming packets related to handshake
*/ */
func (device *Device) RoutineHandshake(id int) { func (device *Device) RoutineHandshake() {
logInfo := device.log.Info
logError := device.log.Error
logDebug := device.log.Debug
var elem QueueHandshakeElement
var ok bool
defer func() { defer func() {
device.log.Verbosef("Routine: handshake worker %d - stopped", id) logDebug.Println("Routine: handshake worker - stopped")
device.queue.encryption.wg.Done() device.state.stopping.Done()
if elem.buffer != nil {
device.PutMessageBuffer(elem.buffer)
}
}() }()
device.log.Verbosef("Routine: handshake worker %d - started", id)
for elem := range device.queue.handshake.c { logDebug.Println("Routine: handshake worker - started")
device.state.starting.Done()
device.aSecMux.RLock() for {
if elem.buffer != nil {
device.PutMessageBuffer(elem.buffer)
elem.buffer = nil
}
select {
case elem, ok = <-device.queue.handshake:
case <-device.signals.stop:
return
}
if !ok {
return
}
// handle cookie fields and ratelimiting // handle cookie fields and ratelimiting
@ -318,8 +337,8 @@ func (device *Device) RoutineHandshake(id int) {
reader := bytes.NewReader(elem.packet) reader := bytes.NewReader(elem.packet)
err := binary.Read(reader, binary.LittleEndian, &reply) err := binary.Read(reader, binary.LittleEndian, &reply)
if err != nil { if err != nil {
device.log.Verbosef("Failed to decode cookie reply") logDebug.Println("Failed to decode cookie reply")
goto skip return
} }
// lookup peer from index // lookup peer from index
@ -327,32 +346,27 @@ func (device *Device) RoutineHandshake(id int) {
entry := device.indexTable.Lookup(reply.Receiver) entry := device.indexTable.Lookup(reply.Receiver)
if entry.peer == nil { if entry.peer == nil {
goto skip continue
} }
// consume reply // consume reply
if peer := entry.peer; peer.isRunning.Load() { if peer := entry.peer; peer.isRunning.Get() {
device.log.Verbosef( logDebug.Println("Receiving cookie response from ", elem.endpoint.DstToString())
"Receiving cookie response from %s",
elem.endpoint.DstToString(),
)
if !peer.cookieGenerator.ConsumeReply(&reply) { if !peer.cookieGenerator.ConsumeReply(&reply) {
device.log.Verbosef( logDebug.Println("Could not decrypt invalid cookie response")
"Could not decrypt invalid cookie response",
)
} }
} }
goto skip continue
case MessageInitiationType, MessageResponseType: case MessageInitiationType, MessageResponseType:
// check mac fields and maybe ratelimit // check mac fields and maybe ratelimit
if !device.cookieChecker.CheckMAC1(elem.packet) { if !device.cookieChecker.CheckMAC1(elem.packet) {
device.log.Verbosef("Received packet with invalid mac1") logDebug.Println("Received packet with invalid mac1")
goto skip continue
} }
// endpoints destination address is the source of the datagram // endpoints destination address is the source of the datagram
@ -363,39 +377,45 @@ func (device *Device) RoutineHandshake(id int) {
if !device.cookieChecker.CheckMAC2(elem.packet, elem.endpoint.DstToBytes()) { if !device.cookieChecker.CheckMAC2(elem.packet, elem.endpoint.DstToBytes()) {
device.SendHandshakeCookie(&elem) device.SendHandshakeCookie(&elem)
goto skip continue
} }
// check ratelimiter // check ratelimiter
if !device.rate.limiter.Allow(elem.endpoint.DstIP()) { if !device.rate.limiter.Allow(elem.endpoint.DstIP()) {
goto skip continue
} }
} }
default: default:
device.log.Errorf("Invalid packet ended up in the handshake queue") logError.Println("Invalid packet ended up in the handshake queue")
goto skip continue
} }
// handle handshake initiation/response content // handle handshake initiation/response content
switch elem.msgType { switch elem.msgType {
case MessageInitiationType: case MessageInitiationType:
// unmarshal // unmarshal
var msg MessageInitiation var msg MessageInitiation
reader := bytes.NewReader(elem.packet) reader := bytes.NewReader(elem.packet)
err := binary.Read(reader, binary.LittleEndian, &msg) err := binary.Read(reader, binary.LittleEndian, &msg)
if err != nil { if err != nil {
device.log.Errorf("Failed to decode initiation message") logError.Println("Failed to decode initiation message")
goto skip continue
} }
// consume initiation // consume initiation
peer := device.ConsumeMessageInitiation(&msg) peer := device.ConsumeMessageInitiation(&msg)
if peer == nil { if peer == nil {
device.log.Verbosef("Received invalid initiation message from %s", elem.endpoint.DstToString()) logInfo.Println(
goto skip "Received invalid initiation message from",
elem.endpoint.DstToString(),
)
continue
} }
// update timers // update timers
@ -406,8 +426,8 @@ func (device *Device) RoutineHandshake(id int) {
// update endpoint // update endpoint
peer.SetEndpointFromPacket(elem.endpoint) peer.SetEndpointFromPacket(elem.endpoint)
device.log.Verbosef("%v - Received handshake initiation", peer) logDebug.Println(peer, "- Received handshake initiation")
peer.rxBytes.Add(uint64(len(elem.packet))) atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)))
peer.SendHandshakeResponse() peer.SendHandshakeResponse()
@ -419,23 +439,26 @@ func (device *Device) RoutineHandshake(id int) {
reader := bytes.NewReader(elem.packet) reader := bytes.NewReader(elem.packet)
err := binary.Read(reader, binary.LittleEndian, &msg) err := binary.Read(reader, binary.LittleEndian, &msg)
if err != nil { if err != nil {
device.log.Errorf("Failed to decode response message") logError.Println("Failed to decode response message")
goto skip continue
} }
// consume response // consume response
peer := device.ConsumeMessageResponse(&msg) peer := device.ConsumeMessageResponse(&msg)
if peer == nil { if peer == nil {
device.log.Verbosef("Received invalid response message from %s", elem.endpoint.DstToString()) logInfo.Println(
goto skip "Received invalid response message from",
elem.endpoint.DstToString(),
)
continue
} }
// update endpoint // update endpoint
peer.SetEndpointFromPacket(elem.endpoint) peer.SetEndpointFromPacket(elem.endpoint)
device.log.Verbosef("%v - Received handshake response", peer) logDebug.Println(peer, "- Received handshake response")
peer.rxBytes.Add(uint64(len(elem.packet))) atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)))
// update timers // update timers
@ -447,131 +470,178 @@ func (device *Device) RoutineHandshake(id int) {
err = peer.BeginSymmetricSession() err = peer.BeginSymmetricSession()
if err != nil { if err != nil {
device.log.Errorf("%v - Failed to derive keypair: %v", peer, err) logError.Println(peer, "- Failed to derive keypair:", err)
goto skip continue
} }
peer.timersSessionDerived() peer.timersSessionDerived()
peer.timersHandshakeComplete() peer.timersHandshakeComplete()
peer.SendKeepalive() peer.SendKeepalive()
select {
case peer.signals.newKeypairArrived <- struct{}{}:
default:
}
} }
skip:
device.aSecMux.RUnlock()
device.PutMessageBuffer(elem.buffer)
} }
} }
func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) { func (peer *Peer) RoutineSequentialReceiver() {
device := peer.device device := peer.device
logInfo := device.log.Info
logError := device.log.Error
logDebug := device.log.Debug
var elem *QueueInboundElement
defer func() { defer func() {
device.log.Verbosef("%v - Routine: sequential receiver - stopped", peer) logDebug.Println(peer, "- Routine: sequential receiver - stopped")
peer.stopping.Done() peer.routines.stopping.Done()
}() if elem != nil {
device.log.Verbosef("%v - Routine: sequential receiver - started", peer) if !elem.IsDropped() {
device.PutMessageBuffer(elem.buffer)
bufs := make([][]byte, 0, maxBatchSize) }
device.PutInboundElement(elem)
for elemsContainer := range peer.queue.inbound.c {
if elemsContainer == nil {
return
} }
elemsContainer.Lock() }()
validTailPacket := -1
dataPacketReceived := false logDebug.Println(peer, "- Routine: sequential receiver - started")
rxBytesLen := uint64(0)
for i, elem := range elemsContainer.elems { peer.routines.starting.Done()
if elem.packet == nil {
// decryption failed for {
continue if elem != nil {
if !elem.IsDropped() {
device.PutMessageBuffer(elem.buffer)
} }
device.PutInboundElement(elem)
elem = nil
}
if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) { var elemOk bool
continue select {
case <-peer.routines.stop:
return
case elem, elemOk = <-peer.queue.inbound:
if !elemOk {
return
} }
}
validTailPacket = i // wait for decryption
if peer.ReceivedWithKeypair(elem.keypair) {
peer.SetEndpointFromPacket(elem.endpoint)
peer.timersHandshakeComplete()
peer.SendStagedPackets()
}
rxBytesLen += uint64(len(elem.packet) + MinMessageSize)
if len(elem.packet) == 0 { elem.Lock()
device.log.Verbosef("%v - Receiving keepalive packet", peer)
continue
}
dataPacketReceived = true
switch elem.packet[0] >> 4 { if elem.IsDropped() {
case 4: continue
if len(elem.packet) < ipv4.HeaderLen { }
continue
}
field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2]
length := binary.BigEndian.Uint16(field)
if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen {
continue
}
elem.packet = elem.packet[:length]
src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
if device.allowedips.Lookup(src) != peer {
device.log.Verbosef("IPv4 packet with disallowed source address from %v", peer)
continue
}
case 6: // check for replay
if len(elem.packet) < ipv6.HeaderLen {
continue
}
field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
length := binary.BigEndian.Uint16(field)
length += ipv6.HeaderLen
if int(length) > len(elem.packet) {
continue
}
elem.packet = elem.packet[:length]
src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
if device.allowedips.Lookup(src) != peer {
device.log.Verbosef("IPv6 packet with disallowed source address from %v", peer)
continue
}
if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) {
continue
}
// update endpoint
peer.SetEndpointFromPacket(elem.endpoint)
// check if using new keypair
if peer.ReceivedWithKeypair(elem.keypair) {
peer.timersHandshakeComplete()
select {
case peer.signals.newKeypairArrived <- struct{}{}:
default: default:
device.log.Verbosef( }
"Packet with invalid IP version from %v", }
peer.keepKeyFreshReceiving()
peer.timersAnyAuthenticatedPacketTraversal()
peer.timersAnyAuthenticatedPacketReceived()
atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)+MinMessageSize))
// check for keepalive
if len(elem.packet) == 0 {
logDebug.Println(peer, "- Receiving keepalive packet")
continue
}
peer.timersDataReceived()
// verify source and strip padding
switch elem.packet[0] >> 4 {
case ipv4.Version:
// strip padding
if len(elem.packet) < ipv4.HeaderLen {
continue
}
field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2]
length := binary.BigEndian.Uint16(field)
if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen {
continue
}
elem.packet = elem.packet[:length]
// verify IPv4 source
src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
if device.allowedips.LookupIPv4(src) != peer {
logInfo.Println(
"IPv4 packet with disallowed source address from",
peer, peer,
) )
continue continue
} }
bufs = append( case ipv6.Version:
bufs,
elem.buffer[:MessageTransportOffsetContent+len(elem.packet)], // strip padding
)
if len(elem.packet) < ipv6.HeaderLen {
continue
}
field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
length := binary.BigEndian.Uint16(field)
length += ipv6.HeaderLen
if int(length) > len(elem.packet) {
continue
}
elem.packet = elem.packet[:length]
// verify IPv6 source
src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
if device.allowedips.LookupIPv6(src) != peer {
logInfo.Println(
"IPv6 packet with disallowed source address from",
peer,
)
continue
}
default:
logInfo.Println("Packet with invalid IP version from", peer)
continue
} }
peer.rxBytes.Add(rxBytesLen) // write to tun device
if validTailPacket >= 0 {
peer.SetEndpointFromPacket(elemsContainer.elems[validTailPacket].endpoint) offset := MessageTransportOffsetContent
peer.keepKeyFreshReceiving() _, err := device.tun.device.Write(elem.buffer[:offset+len(elem.packet)], offset)
peer.timersAnyAuthenticatedPacketTraversal() if len(peer.queue.inbound) == 0 {
peer.timersAnyAuthenticatedPacketReceived() err = device.tun.device.Flush()
} if err != nil {
if dataPacketReceived { peer.device.log.Error.Printf("Unable to flush packets: %v", err)
peer.timersDataReceived()
}
if len(bufs) > 0 {
_, err := device.tun.device.Write(bufs, MessageTransportOffsetContent)
if err != nil && !device.isClosed() {
device.log.Errorf("Failed to write packets to TUN device: %v", err)
} }
} }
for _, elem := range elemsContainer.elems { if err != nil && !device.isClosed.Get() {
device.PutMessageBuffer(elem.buffer) logError.Println("Failed to write packet to TUN device:", err)
device.PutInboundElement(elem)
} }
bufs = bufs[:0]
device.PutInboundElementsContainer(elemsContainer)
} }
} }

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/ */
package device package device
@ -8,14 +8,11 @@ package device
import ( import (
"bytes" "bytes"
"encoding/binary" "encoding/binary"
"errors"
"net" "net"
"os"
"sync" "sync"
"sync/atomic"
"time" "time"
"github.com/amnezia-vpn/amneziawg-go/conn"
"github.com/amnezia-vpn/amneziawg-go/tun"
"golang.org/x/crypto/chacha20poly1305" "golang.org/x/crypto/chacha20poly1305"
"golang.org/x/net/ipv4" "golang.org/x/net/ipv4"
"golang.org/x/net/ipv6" "golang.org/x/net/ipv6"
@ -46,6 +43,8 @@ import (
*/ */
type QueueOutboundElement struct { type QueueOutboundElement struct {
dropped int32
sync.Mutex
buffer *[MaxMessageSize]byte // slice holding the packet data buffer *[MaxMessageSize]byte // slice holding the packet data
packet []byte // slice of "buffer" (always!) packet []byte // slice of "buffer" (always!)
nonce uint64 // nonce for encryption nonce uint64 // nonce for encryption
@ -53,52 +52,80 @@ type QueueOutboundElement struct {
peer *Peer // related peer peer *Peer // related peer
} }
type QueueOutboundElementsContainer struct {
sync.Mutex
elems []*QueueOutboundElement
}
func (device *Device) NewOutboundElement() *QueueOutboundElement { func (device *Device) NewOutboundElement() *QueueOutboundElement {
elem := device.GetOutboundElement() elem := device.GetOutboundElement()
elem.dropped = AtomicFalse
elem.buffer = device.GetMessageBuffer() elem.buffer = device.GetMessageBuffer()
elem.Mutex = sync.Mutex{}
elem.nonce = 0 elem.nonce = 0
// keypair and peer were cleared (if necessary) by clearPointers. elem.keypair = nil
elem.peer = nil
return elem return elem
} }
// clearPointers clears elem fields that contain pointers. func (elem *QueueOutboundElement) Drop() {
// This makes the garbage collector's life easier and atomic.StoreInt32(&elem.dropped, AtomicTrue)
// avoids accidentally keeping other objects around unnecessarily. }
// It also reduces the possible collateral damage from use-after-free bugs.
func (elem *QueueOutboundElement) clearPointers() { func (elem *QueueOutboundElement) IsDropped() bool {
elem.buffer = nil return atomic.LoadInt32(&elem.dropped) == AtomicTrue
elem.packet = nil }
elem.keypair = nil
elem.peer = nil func addToNonceQueue(queue chan *QueueOutboundElement, element *QueueOutboundElement, device *Device) {
for {
select {
case queue <- element:
return
default:
select {
case old := <-queue:
device.PutMessageBuffer(old.buffer)
device.PutOutboundElement(old)
default:
}
}
}
}
func addToOutboundAndEncryptionQueues(outboundQueue chan *QueueOutboundElement, encryptionQueue chan *QueueOutboundElement, element *QueueOutboundElement) {
select {
case outboundQueue <- element:
select {
case encryptionQueue <- element:
return
default:
element.Drop()
element.peer.device.PutMessageBuffer(element.buffer)
element.Unlock()
}
default:
element.peer.device.PutMessageBuffer(element.buffer)
element.peer.device.PutOutboundElement(element)
}
} }
/* Queues a keepalive if no packets are queued for peer /* Queues a keepalive if no packets are queued for peer
*/ */
func (peer *Peer) SendKeepalive() { func (peer *Peer) SendKeepalive() bool {
if len(peer.queue.staged) == 0 && peer.isRunning.Load() { if len(peer.queue.nonce) != 0 || peer.queue.packetInNonceQueueIsAwaitingKey.Get() || !peer.isRunning.Get() {
elem := peer.device.NewOutboundElement() return false
elemsContainer := peer.device.GetOutboundElementsContainer() }
elemsContainer.elems = append(elemsContainer.elems, elem) elem := peer.device.NewOutboundElement()
select { elem.packet = nil
case peer.queue.staged <- elemsContainer: select {
peer.device.log.Verbosef("%v - Sending keepalive packet", peer) case peer.queue.nonce <- elem:
default: peer.device.log.Debug.Println(peer, "- Sending keepalive packet")
peer.device.PutMessageBuffer(elem.buffer) return true
peer.device.PutOutboundElement(elem) default:
peer.device.PutOutboundElementsContainer(elemsContainer) peer.device.PutMessageBuffer(elem.buffer)
} peer.device.PutOutboundElement(elem)
return false
} }
peer.SendStagedPackets()
} }
func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
if !isRetry { if !isRetry {
peer.timers.handshakeAttempts.Store(0) atomic.StoreUint32(&peer.timers.handshakeAttempts, 0)
} }
peer.handshake.mutex.RLock() peer.handshake.mutex.RLock()
@ -116,65 +143,26 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
peer.handshake.lastSentHandshake = time.Now() peer.handshake.lastSentHandshake = time.Now()
peer.handshake.mutex.Unlock() peer.handshake.mutex.Unlock()
peer.device.log.Verbosef("%v - Sending handshake initiation", peer) peer.device.log.Debug.Println(peer, "- Sending handshake initiation")
msg, err := peer.device.CreateMessageInitiation(peer) msg, err := peer.device.CreateMessageInitiation(peer)
if err != nil { if err != nil {
peer.device.log.Errorf("%v - Failed to create initiation message: %v", peer, err) peer.device.log.Error.Println(peer, "- Failed to create initiation message:", err)
return err return err
} }
var sendBuffer [][]byte
// junk is handled separately, so only the real packet is processed for cookie (MAC) generation
var junkedHeader []byte
if peer.device.isAdvancedSecurityOn() {
peer.device.aSecMux.RLock()
junks, err := peer.device.junkCreator.createJunkPackets()
peer.device.aSecMux.RUnlock()
if err != nil { var buff [MessageInitiationSize]byte
peer.device.log.Errorf("%v - %v", peer, err) writer := bytes.NewBuffer(buff[:0])
return err
}
if len(junks) > 0 {
err = peer.SendBuffers(junks)
if err != nil {
peer.device.log.Errorf("%v - Failed to send junk packets: %v", peer, err)
return err
}
}
peer.device.aSecMux.RLock()
if peer.device.aSecCfg.initPacketJunkSize != 0 {
buf := make([]byte, 0, peer.device.aSecCfg.initPacketJunkSize)
writer := bytes.NewBuffer(buf[:0])
err = peer.device.junkCreator.appendJunk(writer, peer.device.aSecCfg.initPacketJunkSize)
if err != nil {
peer.device.log.Errorf("%v - %v", peer, err)
peer.device.aSecMux.RUnlock()
return err
}
junkedHeader = writer.Bytes()
}
peer.device.aSecMux.RUnlock()
}
var buf [MessageInitiationSize]byte
writer := bytes.NewBuffer(buf[:0])
binary.Write(writer, binary.LittleEndian, msg) binary.Write(writer, binary.LittleEndian, msg)
packet := writer.Bytes() packet := writer.Bytes()
peer.cookieGenerator.AddMacs(packet) peer.cookieGenerator.AddMacs(packet)
junkedHeader = append(junkedHeader, packet...)
peer.timersAnyAuthenticatedPacketTraversal() peer.timersAnyAuthenticatedPacketTraversal()
peer.timersAnyAuthenticatedPacketSent() peer.timersAnyAuthenticatedPacketSent()
sendBuffer = append(sendBuffer, junkedHeader) err = peer.SendBuffer(packet)
err = peer.SendBuffers(sendBuffer)
if err != nil { if err != nil {
peer.device.log.Errorf("%v - Failed to send handshake initiation: %v", peer, err) peer.device.log.Error.Println(peer, "- Failed to send handshake initiation", err)
} }
peer.timersHandshakeInitiated() peer.timersHandshakeInitiated()
@ -186,40 +174,23 @@ func (peer *Peer) SendHandshakeResponse() error {
peer.handshake.lastSentHandshake = time.Now() peer.handshake.lastSentHandshake = time.Now()
peer.handshake.mutex.Unlock() peer.handshake.mutex.Unlock()
peer.device.log.Verbosef("%v - Sending handshake response", peer) peer.device.log.Debug.Println(peer, "- Sending handshake response")
response, err := peer.device.CreateMessageResponse(peer) response, err := peer.device.CreateMessageResponse(peer)
if err != nil { if err != nil {
peer.device.log.Errorf("%v - Failed to create response message: %v", peer, err) peer.device.log.Error.Println(peer, "- Failed to create response message:", err)
return err return err
} }
var junkedHeader []byte
if peer.device.isAdvancedSecurityOn() {
peer.device.aSecMux.RLock()
if peer.device.aSecCfg.responsePacketJunkSize != 0 {
buf := make([]byte, 0, peer.device.aSecCfg.responsePacketJunkSize)
writer := bytes.NewBuffer(buf[:0])
err = peer.device.junkCreator.appendJunk(writer, peer.device.aSecCfg.responsePacketJunkSize)
if err != nil {
peer.device.aSecMux.RUnlock()
peer.device.log.Errorf("%v - %v", peer, err)
return err
}
junkedHeader = writer.Bytes()
}
peer.device.aSecMux.RUnlock()
}
var buf [MessageResponseSize]byte
writer := bytes.NewBuffer(buf[:0])
var buff [MessageResponseSize]byte
writer := bytes.NewBuffer(buff[:0])
binary.Write(writer, binary.LittleEndian, response) binary.Write(writer, binary.LittleEndian, response)
packet := writer.Bytes() packet := writer.Bytes()
peer.cookieGenerator.AddMacs(packet) peer.cookieGenerator.AddMacs(packet)
junkedHeader = append(junkedHeader, packet...)
err = peer.BeginSymmetricSession() err = peer.BeginSymmetricSession()
if err != nil { if err != nil {
peer.device.log.Errorf("%v - Failed to derive keypair: %v", peer, err) peer.device.log.Error.Println(peer, "- Failed to derive keypair:", err)
return err return err
} }
@ -227,36 +198,32 @@ func (peer *Peer) SendHandshakeResponse() error {
peer.timersAnyAuthenticatedPacketTraversal() peer.timersAnyAuthenticatedPacketTraversal()
peer.timersAnyAuthenticatedPacketSent() peer.timersAnyAuthenticatedPacketSent()
// TODO: allocation could be avoided err = peer.SendBuffer(packet)
err = peer.SendBuffers([][]byte{junkedHeader})
if err != nil { if err != nil {
peer.device.log.Errorf("%v - Failed to send handshake response: %v", peer, err) peer.device.log.Error.Println(peer, "- Failed to send handshake response", err)
} }
return err return err
} }
func (device *Device) SendHandshakeCookie( func (device *Device) SendHandshakeCookie(initiatingElem *QueueHandshakeElement) error {
initiatingElem *QueueHandshakeElement,
) error { device.log.Debug.Println("Sending cookie response for denied handshake message for", initiatingElem.endpoint.DstToString())
device.log.Verbosef("Sending cookie response for denied handshake message for %v", initiatingElem.endpoint.DstToString())
sender := binary.LittleEndian.Uint32(initiatingElem.packet[4:8]) sender := binary.LittleEndian.Uint32(initiatingElem.packet[4:8])
reply, err := device.cookieChecker.CreateReply( reply, err := device.cookieChecker.CreateReply(initiatingElem.packet, sender, initiatingElem.endpoint.DstToBytes())
initiatingElem.packet,
sender,
initiatingElem.endpoint.DstToBytes(),
)
if err != nil { if err != nil {
device.log.Errorf("Failed to create cookie reply: %v", err) device.log.Error.Println("Failed to create cookie reply:", err)
return err return err
} }
var buf [MessageCookieReplySize]byte var buff [MessageCookieReplySize]byte
writer := bytes.NewBuffer(buf[:0]) writer := bytes.NewBuffer(buff[:0])
binary.Write(writer, binary.LittleEndian, reply) binary.Write(writer, binary.LittleEndian, reply)
// TODO: allocation could be avoided device.net.bind.Send(writer.Bytes(), initiatingElem.endpoint)
device.net.bind.Send([][]byte{writer.Bytes()}, initiatingElem.endpoint) if err != nil {
return nil device.log.Error.Println("Failed to send cookie reply:", err)
}
return err
} }
func (peer *Peer) keepKeyFreshSending() { func (peer *Peer) keepKeyFreshSending() {
@ -264,255 +231,280 @@ func (peer *Peer) keepKeyFreshSending() {
if keypair == nil { if keypair == nil {
return return
} }
nonce := keypair.sendNonce.Load() nonce := atomic.LoadUint64(&keypair.sendNonce)
if nonce > RekeyAfterMessages || (keypair.isInitiator && time.Since(keypair.created) > RekeyAfterTime) { if nonce > RekeyAfterMessages || (keypair.isInitiator && time.Since(keypair.created) > RekeyAfterTime) {
peer.SendHandshakeInitiation(false) peer.SendHandshakeInitiation(false)
} }
} }
/* Reads packets from the TUN and inserts
* into nonce queue for peer
*
* Obs. Single instance per TUN device
*/
func (device *Device) RoutineReadFromTUN() { func (device *Device) RoutineReadFromTUN() {
logDebug := device.log.Debug
logError := device.log.Error
defer func() { defer func() {
device.log.Verbosef("Routine: TUN reader - stopped") logDebug.Println("Routine: TUN reader - stopped")
device.state.stopping.Done() device.state.stopping.Done()
device.queue.encryption.wg.Done()
}() }()
device.log.Verbosef("Routine: TUN reader - started") logDebug.Println("Routine: TUN reader - started")
device.state.starting.Done()
var ( var elem *QueueOutboundElement
batchSize = device.BatchSize()
readErr error
elems = make([]*QueueOutboundElement, batchSize)
bufs = make([][]byte, batchSize)
elemsByPeer = make(map[*Peer]*QueueOutboundElementsContainer, batchSize)
count = 0
sizes = make([]int, batchSize)
offset = MessageTransportHeaderSize
)
for i := range elems { for {
elems[i] = device.NewOutboundElement() if elem != nil {
bufs[i] = elems[i].buffer[:] device.PutMessageBuffer(elem.buffer)
device.PutOutboundElement(elem)
}
elem = device.NewOutboundElement()
// read packet
offset := MessageTransportHeaderSize
size, err := device.tun.device.Read(elem.buffer[:], offset)
if err != nil {
if !device.isClosed.Get() {
logError.Println("Failed to read packet from TUN device:", err)
device.Close()
}
device.PutMessageBuffer(elem.buffer)
device.PutOutboundElement(elem)
return
}
if size == 0 || size > MaxContentSize {
continue
}
elem.packet = elem.buffer[offset : offset+size]
// lookup peer
var peer *Peer
switch elem.packet[0] >> 4 {
case ipv4.Version:
if len(elem.packet) < ipv4.HeaderLen {
continue
}
dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
peer = device.allowedips.LookupIPv4(dst)
case ipv6.Version:
if len(elem.packet) < ipv6.HeaderLen {
continue
}
dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
peer = device.allowedips.LookupIPv6(dst)
default:
logDebug.Println("Received packet with unknown IP version")
}
if peer == nil {
continue
}
// insert into nonce/pre-handshake queue
if peer.isRunning.Get() {
if peer.queue.packetInNonceQueueIsAwaitingKey.Get() {
peer.SendHandshakeInitiation(false)
}
addToNonceQueue(peer.queue.nonce, elem, device)
elem = nil
}
} }
}
defer func() { func (peer *Peer) FlushNonceQueue() {
for _, elem := range elems { select {
if elem != nil { case peer.signals.flushNonceQueue <- struct{}{}:
default:
}
}
/* Queues packets when there is no handshake.
* Then assigns nonces to packets sequentially
* and creates "work" structs for workers
*
* Obs. A single instance per peer
*/
func (peer *Peer) RoutineNonce() {
var keypair *Keypair
device := peer.device
logDebug := device.log.Debug
flush := func() {
for {
select {
case elem := <-peer.queue.nonce:
device.PutMessageBuffer(elem.buffer) device.PutMessageBuffer(elem.buffer)
device.PutOutboundElement(elem) device.PutOutboundElement(elem)
default:
return
} }
} }
}
defer func() {
flush()
logDebug.Println(peer, "- Routine: nonce worker - stopped")
peer.queue.packetInNonceQueueIsAwaitingKey.Set(false)
peer.routines.stopping.Done()
}() }()
peer.routines.starting.Done()
logDebug.Println(peer, "- Routine: nonce worker - started")
for { for {
// read packets NextPacket:
count, readErr = device.tun.device.Read(bufs, sizes, offset) peer.queue.packetInNonceQueueIsAwaitingKey.Set(false)
for i := 0; i < count; i++ {
if sizes[i] < 1 {
continue
}
elem := elems[i] select {
elem.packet = bufs[i][offset : offset+sizes[i]] case <-peer.routines.stop:
return
// lookup peer case <-peer.signals.flushNonceQueue:
var peer *Peer flush()
switch elem.packet[0] >> 4 { goto NextPacket
case 4:
if len(elem.packet) < ipv4.HeaderLen {
continue
}
dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
peer = device.allowedips.Lookup(dst)
case 6: case elem, ok := <-peer.queue.nonce:
if len(elem.packet) < ipv6.HeaderLen {
continue
}
dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
peer = device.allowedips.Lookup(dst)
default:
device.log.Verbosef("Received packet with unknown IP version")
}
if peer == nil {
continue
}
elemsForPeer, ok := elemsByPeer[peer]
if !ok { if !ok {
elemsForPeer = device.GetOutboundElementsContainer() return
elemsByPeer[peer] = elemsForPeer
} }
elemsForPeer.elems = append(elemsForPeer.elems, elem)
elems[i] = device.NewOutboundElement()
bufs[i] = elems[i].buffer[:]
}
for peer, elemsForPeer := range elemsByPeer { // make sure to always pick the newest key
if peer.isRunning.Load() {
peer.StagePackets(elemsForPeer) for {
peer.SendStagedPackets()
} else { // check validity of newest key pair
for _, elem := range elemsForPeer.elems {
keypair = peer.keypairs.Current()
if keypair != nil && keypair.sendNonce < RejectAfterMessages {
if time.Since(keypair.created) < RejectAfterTime {
break
}
}
peer.queue.packetInNonceQueueIsAwaitingKey.Set(true)
// no suitable key pair, request for new handshake
select {
case <-peer.signals.newKeypairArrived:
default:
}
peer.SendHandshakeInitiation(false)
// wait for key to be established
logDebug.Println(peer, "- Awaiting keypair")
select {
case <-peer.signals.newKeypairArrived:
logDebug.Println(peer, "- Obtained awaited keypair")
case <-peer.signals.flushNonceQueue:
device.PutMessageBuffer(elem.buffer) device.PutMessageBuffer(elem.buffer)
device.PutOutboundElement(elem) device.PutOutboundElement(elem)
flush()
goto NextPacket
case <-peer.routines.stop:
device.PutMessageBuffer(elem.buffer)
device.PutOutboundElement(elem)
return
} }
device.PutOutboundElementsContainer(elemsForPeer)
} }
delete(elemsByPeer, peer) peer.queue.packetInNonceQueueIsAwaitingKey.Set(false)
}
if readErr != nil { // populate work element
if errors.Is(readErr, tun.ErrTooManySegments) {
// TODO: record stat for this
// This will happen if MSS is surprisingly small (< 576)
// coincident with reasonably high throughput.
device.log.Verbosef("Dropped some packets from multi-segment read: %v", readErr)
continue
}
if !device.isClosed() {
if !errors.Is(readErr, os.ErrClosed) {
device.log.Errorf("Failed to read packet from TUN device: %v", readErr)
}
go device.Close()
}
return
}
}
}
func (peer *Peer) StagePackets(elems *QueueOutboundElementsContainer) { elem.peer = peer
for { elem.nonce = atomic.AddUint64(&keypair.sendNonce, 1) - 1
select {
case peer.queue.staged <- elems:
return
default:
}
select {
case tooOld := <-peer.queue.staged:
for _, elem := range tooOld.elems {
peer.device.PutMessageBuffer(elem.buffer)
peer.device.PutOutboundElement(elem)
}
peer.device.PutOutboundElementsContainer(tooOld)
default:
}
}
}
func (peer *Peer) SendStagedPackets() { // double check in case of race condition added by future code
top:
if len(peer.queue.staged) == 0 || !peer.device.isUp() {
return
}
keypair := peer.keypairs.Current() if elem.nonce >= RejectAfterMessages {
if keypair == nil || keypair.sendNonce.Load() >= RejectAfterMessages || time.Since(keypair.created) >= RejectAfterTime { atomic.StoreUint64(&keypair.sendNonce, RejectAfterMessages)
peer.SendHandshakeInitiation(false) device.PutMessageBuffer(elem.buffer)
return device.PutOutboundElement(elem)
} goto NextPacket
for {
var elemsContainerOOO *QueueOutboundElementsContainer
select {
case elemsContainer := <-peer.queue.staged:
i := 0
for _, elem := range elemsContainer.elems {
elem.peer = peer
elem.nonce = keypair.sendNonce.Add(1) - 1
if elem.nonce >= RejectAfterMessages {
keypair.sendNonce.Store(RejectAfterMessages)
if elemsContainerOOO == nil {
elemsContainerOOO = peer.device.GetOutboundElementsContainer()
}
elemsContainerOOO.elems = append(elemsContainerOOO.elems, elem)
continue
} else {
elemsContainer.elems[i] = elem
i++
}
elem.keypair = keypair
}
elemsContainer.Lock()
elemsContainer.elems = elemsContainer.elems[:i]
if elemsContainerOOO != nil {
peer.StagePackets(elemsContainerOOO) // XXX: Out of order, but we can't front-load go chans
} }
if len(elemsContainer.elems) == 0 { elem.keypair = keypair
peer.device.PutOutboundElementsContainer(elemsContainer) elem.dropped = AtomicFalse
goto top elem.Lock()
}
// add to parallel and sequential queue // add to parallel and sequential queue
if peer.isRunning.Load() { addToOutboundAndEncryptionQueues(peer.queue.outbound, device.queue.encryption, elem)
peer.queue.outbound.c <- elemsContainer
peer.device.queue.encryption.c <- elemsContainer
} else {
for _, elem := range elemsContainer.elems {
peer.device.PutMessageBuffer(elem.buffer)
peer.device.PutOutboundElement(elem)
}
peer.device.PutOutboundElementsContainer(elemsContainer)
}
if elemsContainerOOO != nil {
goto top
}
default:
return
} }
} }
} }
func (peer *Peer) FlushStagedPackets() {
for {
select {
case elemsContainer := <-peer.queue.staged:
for _, elem := range elemsContainer.elems {
peer.device.PutMessageBuffer(elem.buffer)
peer.device.PutOutboundElement(elem)
}
peer.device.PutOutboundElementsContainer(elemsContainer)
default:
return
}
}
}
func calculatePaddingSize(packetSize, mtu int) int {
lastUnit := packetSize
if mtu == 0 {
return ((lastUnit + PaddingMultiple - 1) & ^(PaddingMultiple - 1)) - lastUnit
}
if lastUnit > mtu {
lastUnit %= mtu
}
paddedSize := ((lastUnit + PaddingMultiple - 1) & ^(PaddingMultiple - 1))
if paddedSize > mtu {
paddedSize = mtu
}
return paddedSize - lastUnit
}
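
calculatePaddingSize rounds the plaintext length up to the next multiple of PaddingMultiple (16, per the "pad content to multiple of 16" comment below) but never pads past the MTU, and a zero-length keepalive stays empty. A quick standalone check of the arithmetic; the function is unexported, so this sketch simply mirrors it under the assumption that the multiple is 16:

package main

import "fmt"

const paddingMultiple = 16 // mirrors the device package's PaddingMultiple

// mirror of calculatePaddingSize above, for illustration only
func calculatePaddingSize(packetSize, mtu int) int {
	lastUnit := packetSize
	if mtu == 0 {
		return ((lastUnit + paddingMultiple - 1) & ^(paddingMultiple - 1)) - lastUnit
	}
	if lastUnit > mtu {
		lastUnit %= mtu
	}
	paddedSize := (lastUnit + paddingMultiple - 1) & ^(paddingMultiple - 1)
	if paddedSize > mtu {
		paddedSize = mtu
	}
	return paddedSize - lastUnit
}

func main() {
	fmt.Println(calculatePaddingSize(21, 1420))   // 11 -> pads 21 bytes up to 32
	fmt.Println(calculatePaddingSize(0, 1420))    // 0  -> keepalives stay empty
	fmt.Println(calculatePaddingSize(1419, 1420)) // 1  -> padding is capped at the MTU
}
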
/* Encrypts the elements in the queue /* Encrypts the elements in the queue
* and marks them for sequential consumption (by releasing the mutex) * and marks them for sequential consumption (by releasing the mutex)
* *
* Obs. One instance per core * Obs. One instance per core
*/ */
func (device *Device) RoutineEncryption(id int) { func (device *Device) RoutineEncryption() {
var paddingZeros [PaddingMultiple]byte
var nonce [chacha20poly1305.NonceSize]byte var nonce [chacha20poly1305.NonceSize]byte
defer device.log.Verbosef("Routine: encryption worker %d - stopped", id) logDebug := device.log.Debug
device.log.Verbosef("Routine: encryption worker %d - started", id)
defer func() {
for {
select {
case elem, ok := <-device.queue.encryption:
if ok && !elem.IsDropped() {
elem.Drop()
device.PutMessageBuffer(elem.buffer)
elem.Unlock()
}
default:
goto out
}
}
out:
logDebug.Println("Routine: encryption worker - stopped")
device.state.stopping.Done()
}()
logDebug.Println("Routine: encryption worker - started")
device.state.starting.Done()
for {
// fetch next element
select {
case <-device.signals.stop:
return
case elem, ok := <-device.queue.encryption:
if !ok {
return
}
// check if dropped
if elem.IsDropped() {
continue
}
for elemsContainer := range device.queue.encryption.c {
for _, elem := range elemsContainer.elems {
// populate header fields // populate header fields
header := elem.buffer[:MessageTransportHeaderSize] header := elem.buffer[:MessageTransportHeaderSize]
fieldType := header[0:4] fieldType := header[0:4]
@ -524,8 +516,16 @@ func (device *Device) RoutineEncryption(id int) {
binary.LittleEndian.PutUint64(fieldNonce, elem.nonce) binary.LittleEndian.PutUint64(fieldNonce, elem.nonce)
// pad content to multiple of 16 // pad content to multiple of 16
paddingSize := calculatePaddingSize(len(elem.packet), int(device.tun.mtu.Load()))
elem.packet = append(elem.packet, paddingZeros[:paddingSize]...) mtu := int(atomic.LoadInt32(&device.tun.mtu))
lastUnit := len(elem.packet) % mtu
paddedSize := (lastUnit + PaddingMultiple - 1) & ^(PaddingMultiple - 1)
if paddedSize > mtu {
paddedSize = mtu
}
for i := len(elem.packet); i < paddedSize; i++ {
elem.packet = append(elem.packet, 0)
}
// encrypt content and release to consumer // encrypt content and release to consumer
@ -536,73 +536,82 @@ func (device *Device) RoutineEncryption(id int) {
elem.packet, elem.packet,
nil, nil,
) )
elem.Unlock()
} }
elemsContainer.Unlock()
} }
} }
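
The encryption workers above rely on a lock-per-batch handoff: the producer locks a container, queues it to both the shared encryption queue and the peer's sequential outbound queue, a worker unlocks it once its packets are encrypted, and the sequential sender re-locks it, which blocks until encryption is done, so send order matches submission order. A minimal sketch of that pattern in isolation (generic names, nothing from the device package):

package main

import (
	"fmt"
	"sync"
)

// batch stands in for a QueueOutboundElementsContainer: a mutex plus payload.
type batch struct {
	sync.Mutex
	n int
}

func main() {
	parallel := make(chan *batch, 8)   // shared work queue, consumed by any worker
	sequential := make(chan *batch, 8) // per-consumer queue, preserves submission order

	// parallel workers: finish in any order, then unlock the batch
	for w := 0; w < 4; w++ {
		go func() {
			for b := range parallel {
				b.n *= b.n // stand-in for encryption
				b.Unlock() // mark the batch as done
			}
		}()
	}

	for i := 1; i <= 5; i++ {
		b := &batch{n: i}
		b.Lock() // stays locked until a worker finishes it
		sequential <- b
		parallel <- b
	}
	close(parallel)

	for i := 0; i < 5; i++ {
		b := <-sequential
		b.Lock()         // blocks until the worker's Unlock
		fmt.Println(b.n) // 1 4 9 16 25, in submission order
	}
}
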
func (peer *Peer) RoutineSequentialSender(maxBatchSize int) { /* Sequentially reads packets from queue and sends to endpoint
*
* Obs. Single instance per peer.
* The routine terminates when the outbound queue is closed.
*/
func (peer *Peer) RoutineSequentialSender() {
device := peer.device device := peer.device
logDebug := device.log.Debug
logError := device.log.Error
defer func() { defer func() {
defer device.log.Verbosef("%v - Routine: sequential sender - stopped", peer) for {
peer.stopping.Done() select {
case elem, ok := <-peer.queue.outbound:
if ok {
if !elem.IsDropped() {
device.PutMessageBuffer(elem.buffer)
elem.Drop()
}
device.PutOutboundElement(elem)
}
default:
goto out
}
}
out:
logDebug.Println(peer, "- Routine: sequential sender - stopped")
peer.routines.stopping.Done()
}() }()
device.log.Verbosef("%v - Routine: sequential sender - started", peer)
bufs := make([][]byte, 0, maxBatchSize) logDebug.Println(peer, "- Routine: sequential sender - started")
for elemsContainer := range peer.queue.outbound.c { peer.routines.starting.Done()
bufs = bufs[:0]
if elemsContainer == nil { for {
select {
case <-peer.routines.stop:
return return
}
if !peer.isRunning.Load() { case elem, ok := <-peer.queue.outbound:
// peer has been stopped; return re-usable elems to the shared pool.
// This is an optimization only. It is possible for the peer to be stopped if !ok {
// immediately after this check, in which case, elem will get processed. return
// The timers and SendBuffers code are resilient to a few stragglers. }
// TODO: rework peer shutdown order to ensure
// that we never accidentally keep timers alive longer than necessary. elem.Lock()
elemsContainer.Lock() if elem.IsDropped() {
for _, elem := range elemsContainer.elems {
device.PutMessageBuffer(elem.buffer)
device.PutOutboundElement(elem) device.PutOutboundElement(elem)
continue
} }
continue
} peer.timersAnyAuthenticatedPacketTraversal()
dataSent := false peer.timersAnyAuthenticatedPacketSent()
elemsContainer.Lock()
for _, elem := range elemsContainer.elems { // send message and return buffer to pool
err := peer.SendBuffer(elem.packet)
if len(elem.packet) != MessageKeepaliveSize { if len(elem.packet) != MessageKeepaliveSize {
dataSent = true peer.timersDataSent()
} }
bufs = append(bufs, elem.packet)
}
peer.timersAnyAuthenticatedPacketTraversal()
peer.timersAnyAuthenticatedPacketSent()
err := peer.SendBuffers(bufs)
if dataSent {
peer.timersDataSent()
}
for _, elem := range elemsContainer.elems {
device.PutMessageBuffer(elem.buffer) device.PutMessageBuffer(elem.buffer)
device.PutOutboundElement(elem) device.PutOutboundElement(elem)
} if err != nil {
device.PutOutboundElementsContainer(elemsContainer) logError.Println(peer, "- Failed to send data packet", err)
if err != nil { continue
var errGSO conn.ErrUDPGSODisabled
if errors.As(err, &errGSO) {
device.log.Verbosef(err.Error())
err = errGSO.RetryErr
} }
}
if err != nil {
device.log.Errorf("%v - Failed to send data packets: %v", peer, err)
continue
}
peer.keepKeyFreshSending() peer.keepKeyFreshSending()
}
} }
} }
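On the master side the sender collects every packet of a container and hands them to SendBuffers in one call, so the bind can transmit them as a batch; when the bind reports that UDP GSO had to be disabled it substitutes the retry result for the original error before deciding whether the send failed. A reduced sketch of the collect-and-send step, with a hypothetical interface standing in for the peer's bind:

// batchSender is a stand-in for the peer's underlying bind.
type batchSender interface {
    SendBuffers(bufs [][]byte) error
}

// sendContainer forwards every packet of one container in a single call and
// reports whether the batch carried anything other than keepalives, which is
// what decides whether timersDataSent fires.
func sendContainer(s batchSender, packets [][]byte, keepaliveSize int) (dataSent bool, err error) {
    bufs := make([][]byte, 0, len(packets))
    for _, p := range packets {
        if len(p) != keepaliveSize {
            dataSent = true
        }
        bufs = append(bufs, p)
    }
    return dataSent, s.SendBuffers(bufs)
}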
@ -1,12 +0,0 @@
//go:build !linux
package device
import (
"github.com/amnezia-vpn/amneziawg-go/conn"
"github.com/amnezia-vpn/amneziawg-go/rwcancel"
)
func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) {
return nil, nil
}
@ -1,224 +0,0 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*
* This implements userspace semantics of "sticky sockets", modeled after
* WireGuard's kernelspace implementation. This is more or less a straight port
* of the sticky-sockets.c example code:
* https://git.zx2c4.com/WireGuard/tree/contrib/examples/sticky-sockets/sticky-sockets.c
*
* Currently there is no way to achieve this within the net package:
* See e.g. https://github.com/golang/go/issues/17930
* So this code remains platform dependent.
*/
package device
import (
"sync"
"unsafe"
"golang.org/x/sys/unix"
"github.com/amnezia-vpn/amneziawg-go/conn"
"github.com/amnezia-vpn/amneziawg-go/rwcancel"
)
func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) {
if !conn.StdNetSupportsStickySockets {
return nil, nil
}
if _, ok := bind.(*conn.StdNetBind); !ok {
return nil, nil
}
netlinkSock, err := createNetlinkRouteSocket()
if err != nil {
return nil, err
}
netlinkCancel, err := rwcancel.NewRWCancel(netlinkSock)
if err != nil {
unix.Close(netlinkSock)
return nil, err
}
go device.routineRouteListener(bind, netlinkSock, netlinkCancel)
return netlinkCancel, nil
}
func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netlinkCancel *rwcancel.RWCancel) {
type peerEndpointPtr struct {
peer *Peer
endpoint *conn.Endpoint
}
var reqPeer map[uint32]peerEndpointPtr
var reqPeerLock sync.Mutex
defer netlinkCancel.Close()
defer unix.Close(netlinkSock)
for msg := make([]byte, 1<<16); ; {
var err error
var msgn int
for {
msgn, _, _, _, err = unix.Recvmsg(netlinkSock, msg[:], nil, 0)
if err == nil || !rwcancel.RetryAfterError(err) {
break
}
if !netlinkCancel.ReadyRead() {
return
}
}
if err != nil {
return
}
for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; {
hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0]))
if uint(hdr.Len) > uint(len(remain)) {
break
}
switch hdr.Type {
case unix.RTM_NEWROUTE, unix.RTM_DELROUTE:
if hdr.Seq <= MaxPeers && hdr.Seq > 0 {
if uint(len(remain)) < uint(hdr.Len) {
break
}
if hdr.Len > unix.SizeofNlMsghdr+unix.SizeofRtMsg {
attr := remain[unix.SizeofNlMsghdr+unix.SizeofRtMsg:]
for {
if uint(len(attr)) < uint(unix.SizeofRtAttr) {
break
}
attrhdr := *(*unix.RtAttr)(unsafe.Pointer(&attr[0]))
if attrhdr.Len < unix.SizeofRtAttr || uint(len(attr)) < uint(attrhdr.Len) {
break
}
if attrhdr.Type == unix.RTA_OIF && attrhdr.Len == unix.SizeofRtAttr+4 {
ifidx := *(*uint32)(unsafe.Pointer(&attr[unix.SizeofRtAttr]))
reqPeerLock.Lock()
if reqPeer == nil {
reqPeerLock.Unlock()
break
}
pePtr, ok := reqPeer[hdr.Seq]
reqPeerLock.Unlock()
if !ok {
break
}
pePtr.peer.endpoint.Lock()
if &pePtr.peer.endpoint.val != pePtr.endpoint {
pePtr.peer.endpoint.Unlock()
break
}
if uint32(pePtr.peer.endpoint.val.(*conn.StdNetEndpoint).SrcIfidx()) == ifidx {
pePtr.peer.endpoint.Unlock()
break
}
pePtr.peer.endpoint.clearSrcOnTx = true
pePtr.peer.endpoint.Unlock()
}
attr = attr[attrhdr.Len:]
}
}
break
}
reqPeerLock.Lock()
reqPeer = make(map[uint32]peerEndpointPtr)
reqPeerLock.Unlock()
go func() {
device.peers.RLock()
i := uint32(1)
for _, peer := range device.peers.keyMap {
peer.endpoint.Lock()
if peer.endpoint.val == nil {
peer.endpoint.Unlock()
continue
}
nativeEP, _ := peer.endpoint.val.(*conn.StdNetEndpoint)
if nativeEP == nil {
peer.endpoint.Unlock()
continue
}
if nativeEP.DstIP().Is6() || nativeEP.SrcIfidx() == 0 {
peer.endpoint.Unlock()
break
}
nlmsg := struct {
hdr unix.NlMsghdr
msg unix.RtMsg
dsthdr unix.RtAttr
dst [4]byte
srchdr unix.RtAttr
src [4]byte
markhdr unix.RtAttr
mark uint32
}{
unix.NlMsghdr{
Type: uint16(unix.RTM_GETROUTE),
Flags: unix.NLM_F_REQUEST,
Seq: i,
},
unix.RtMsg{
Family: unix.AF_INET,
Dst_len: 32,
Src_len: 32,
},
unix.RtAttr{
Len: 8,
Type: unix.RTA_DST,
},
nativeEP.DstIP().As4(),
unix.RtAttr{
Len: 8,
Type: unix.RTA_SRC,
},
nativeEP.SrcIP().As4(),
unix.RtAttr{
Len: 8,
Type: unix.RTA_MARK,
},
device.net.fwmark,
}
nlmsg.hdr.Len = uint32(unsafe.Sizeof(nlmsg))
reqPeerLock.Lock()
reqPeer[i] = peerEndpointPtr{
peer: peer,
endpoint: &peer.endpoint.val,
}
reqPeerLock.Unlock()
peer.endpoint.Unlock()
i++
_, err := netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:])
if err != nil {
break
}
}
device.peers.RUnlock()
}()
}
remain = remain[hdr.Len:]
}
}
}
func createNetlinkRouteSocket() (int, error) {
sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW|unix.SOCK_CLOEXEC, unix.NETLINK_ROUTE)
if err != nil {
return -1, err
}
saddr := &unix.SockaddrNetlink{
Family: unix.AF_NETLINK,
Groups: unix.RTMGRP_IPV4_ROUTE,
}
err = unix.Bind(sock, saddr)
if err != nil {
unix.Close(sock)
return -1, err
}
return sock, nil
}
@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
* *
* This is based heavily on timers.c from the kernel implementation. * This is based heavily on timers.c from the kernel implementation.
*/ */
@ -8,16 +8,16 @@
package device package device
import ( import (
"math/rand"
"sync" "sync"
"sync/atomic"
"time" "time"
_ "unsafe"
) )
//go:linkname fastrandn runtime.fastrandn /* This Timer structure and related functions should roughly copy the interface of
func fastrandn(n uint32) uint32 * the Linux kernel's struct timer_list.
*/
// A Timer manages time-based aspects of the WireGuard protocol.
// Timer roughly copies the interface of the Linux kernel's struct timer_list.
type Timer struct { type Timer struct {
*time.Timer *time.Timer
modifyingLock sync.RWMutex modifyingLock sync.RWMutex
@ -29,17 +29,18 @@ func (peer *Peer) NewTimer(expirationFunction func(*Peer)) *Timer {
timer := &Timer{} timer := &Timer{}
timer.Timer = time.AfterFunc(time.Hour, func() { timer.Timer = time.AfterFunc(time.Hour, func() {
timer.runningLock.Lock() timer.runningLock.Lock()
defer timer.runningLock.Unlock()
timer.modifyingLock.Lock() timer.modifyingLock.Lock()
if !timer.isPending { if !timer.isPending {
timer.modifyingLock.Unlock() timer.modifyingLock.Unlock()
timer.runningLock.Unlock()
return return
} }
timer.isPending = false timer.isPending = false
timer.modifyingLock.Unlock() timer.modifyingLock.Unlock()
expirationFunction(peer) expirationFunction(peer)
timer.runningLock.Unlock()
}) })
timer.Stop() timer.Stop()
return timer return timer
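The wrapper is driven through Mod, Del and IsPending, and the expiration callback holds runningLock for its whole run so other code can synchronise with an in-flight expiration. The Mod/Del implementations sit outside this hunk; what follows is a reduced sketch of the same timer_list-style interface, not the repository's code:

package timersketch

import (
    "sync"
    "time"
)

// sketchTimer mimics the pending-flag bookkeeping used above.
type sketchTimer struct {
    *time.Timer
    mu        sync.Mutex
    isPending bool
}

func newSketchTimer(expired func()) *sketchTimer {
    t := &sketchTimer{}
    t.Timer = time.AfterFunc(time.Hour, func() {
        t.mu.Lock()
        if !t.isPending { // Del won the race: do nothing
            t.mu.Unlock()
            return
        }
        t.isPending = false
        t.mu.Unlock()
        expired()
    })
    t.Stop()
    return t
}

// Mod (re)arms the timer, like mod_timer in the kernel.
func (t *sketchTimer) Mod(d time.Duration) {
    t.mu.Lock()
    t.isPending = true
    t.Reset(d)
    t.mu.Unlock()
}

// Del disarms it, like del_timer.
func (t *sketchTimer) Del() {
    t.mu.Lock()
    t.isPending = false
    t.Stop()
    t.mu.Unlock()
}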
@ -73,12 +74,12 @@ func (timer *Timer) IsPending() bool {
} }
func (peer *Peer) timersActive() bool { func (peer *Peer) timersActive() bool {
return peer.isRunning.Load() && peer.device != nil && peer.device.isUp() return peer.isRunning.Get() && peer.device != nil && peer.device.isUp.Get() && len(peer.device.peers.keyMap) > 0
} }
func expiredRetransmitHandshake(peer *Peer) { func expiredRetransmitHandshake(peer *Peer) {
if peer.timers.handshakeAttempts.Load() > MaxTimerHandshakes { if atomic.LoadUint32(&peer.timers.handshakeAttempts) > MaxTimerHandshakes {
peer.device.log.Verbosef("%s - Handshake did not complete after %d attempts, giving up", peer, MaxTimerHandshakes+2) peer.device.log.Debug.Printf("%s - Handshake did not complete after %d attempts, giving up\n", peer, MaxTimerHandshakes+2)
if peer.timersActive() { if peer.timersActive() {
peer.timers.sendKeepalive.Del() peer.timers.sendKeepalive.Del()
@ -87,7 +88,7 @@ func expiredRetransmitHandshake(peer *Peer) {
/* We drop all packets without a keypair and don't try again, /* We drop all packets without a keypair and don't try again,
* if we try unsuccessfully for too long to make a handshake. * if we try unsuccessfully for too long to make a handshake.
*/ */
peer.FlushStagedPackets() peer.FlushNonceQueue()
/* We set a timer for destroying any residue that might be left /* We set a timer for destroying any residue that might be left
* of a partial exchange. * of a partial exchange.
@ -96,11 +97,15 @@ func expiredRetransmitHandshake(peer *Peer) {
peer.timers.zeroKeyMaterial.Mod(RejectAfterTime * 3) peer.timers.zeroKeyMaterial.Mod(RejectAfterTime * 3)
} }
} else { } else {
peer.timers.handshakeAttempts.Add(1) atomic.AddUint32(&peer.timers.handshakeAttempts, 1)
peer.device.log.Verbosef("%s - Handshake did not complete after %d seconds, retrying (try %d)", peer, int(RekeyTimeout.Seconds()), peer.timers.handshakeAttempts.Load()+1) peer.device.log.Debug.Printf("%s - Handshake did not complete after %d seconds, retrying (try %d)\n", peer, int(RekeyTimeout.Seconds()), atomic.LoadUint32(&peer.timers.handshakeAttempts)+1)
/* We clear the endpoint address src address, in case this is the cause of trouble. */ /* We clear the endpoint address src address, in case this is the cause of trouble. */
peer.markEndpointSrcForClearing() peer.Lock()
if peer.endpoint != nil {
peer.endpoint.ClearSrc()
}
peer.Unlock()
peer.SendHandshakeInitiation(true) peer.SendHandshakeInitiation(true)
} }
@ -108,8 +113,8 @@ func expiredRetransmitHandshake(peer *Peer) {
func expiredSendKeepalive(peer *Peer) { func expiredSendKeepalive(peer *Peer) {
peer.SendKeepalive() peer.SendKeepalive()
if peer.timers.needAnotherKeepalive.Load() { if peer.timers.needAnotherKeepalive.Get() {
peer.timers.needAnotherKeepalive.Store(false) peer.timers.needAnotherKeepalive.Set(false)
if peer.timersActive() { if peer.timersActive() {
peer.timers.sendKeepalive.Mod(KeepaliveTimeout) peer.timers.sendKeepalive.Mod(KeepaliveTimeout)
} }
@ -117,19 +122,24 @@ func expiredSendKeepalive(peer *Peer) {
} }
func expiredNewHandshake(peer *Peer) { func expiredNewHandshake(peer *Peer) {
peer.device.log.Verbosef("%s - Retrying handshake because we stopped hearing back after %d seconds", peer, int((KeepaliveTimeout + RekeyTimeout).Seconds())) peer.device.log.Debug.Printf("%s - Retrying handshake because we stopped hearing back after %d seconds\n", peer, int((KeepaliveTimeout + RekeyTimeout).Seconds()))
/* We clear the endpoint address src address, in case this is the cause of trouble. */ /* We clear the endpoint address src address, in case this is the cause of trouble. */
peer.markEndpointSrcForClearing() peer.Lock()
if peer.endpoint != nil {
peer.endpoint.ClearSrc()
}
peer.Unlock()
peer.SendHandshakeInitiation(false) peer.SendHandshakeInitiation(false)
} }
func expiredZeroKeyMaterial(peer *Peer) { func expiredZeroKeyMaterial(peer *Peer) {
peer.device.log.Verbosef("%s - Removing all keys, since we haven't received a new one in %d seconds", peer, int((RejectAfterTime * 3).Seconds())) peer.device.log.Debug.Printf("%s - Removing all keys, since we haven't received a new one in %d seconds\n", peer, int((RejectAfterTime * 3).Seconds()))
peer.ZeroAndFlushAll() peer.ZeroAndFlushAll()
} }
func expiredPersistentKeepalive(peer *Peer) { func expiredPersistentKeepalive(peer *Peer) {
if peer.persistentKeepaliveInterval.Load() > 0 { if peer.persistentKeepaliveInterval > 0 {
peer.SendKeepalive() peer.SendKeepalive()
} }
} }
@ -137,7 +147,7 @@ func expiredPersistentKeepalive(peer *Peer) {
/* Should be called after an authenticated data packet is sent. */ /* Should be called after an authenticated data packet is sent. */
func (peer *Peer) timersDataSent() { func (peer *Peer) timersDataSent() {
if peer.timersActive() && !peer.timers.newHandshake.IsPending() { if peer.timersActive() && !peer.timers.newHandshake.IsPending() {
peer.timers.newHandshake.Mod(KeepaliveTimeout + RekeyTimeout + time.Millisecond*time.Duration(fastrandn(RekeyTimeoutJitterMaxMs))) peer.timers.newHandshake.Mod(KeepaliveTimeout + RekeyTimeout + time.Millisecond*time.Duration(rand.Int31n(RekeyTimeoutJitterMaxMs)))
} }
} }
@ -147,7 +157,7 @@ func (peer *Peer) timersDataReceived() {
if !peer.timers.sendKeepalive.IsPending() { if !peer.timers.sendKeepalive.IsPending() {
peer.timers.sendKeepalive.Mod(KeepaliveTimeout) peer.timers.sendKeepalive.Mod(KeepaliveTimeout)
} else { } else {
peer.timers.needAnotherKeepalive.Store(true) peer.timers.needAnotherKeepalive.Set(true)
} }
} }
} }
@ -169,7 +179,7 @@ func (peer *Peer) timersAnyAuthenticatedPacketReceived() {
/* Should be called after a handshake initiation message is sent. */ /* Should be called after a handshake initiation message is sent. */
func (peer *Peer) timersHandshakeInitiated() { func (peer *Peer) timersHandshakeInitiated() {
if peer.timersActive() { if peer.timersActive() {
peer.timers.retransmitHandshake.Mod(RekeyTimeout + time.Millisecond*time.Duration(fastrandn(RekeyTimeoutJitterMaxMs))) peer.timers.retransmitHandshake.Mod(RekeyTimeout + time.Millisecond*time.Duration(rand.Int31n(RekeyTimeoutJitterMaxMs)))
} }
} }
@ -178,9 +188,9 @@ func (peer *Peer) timersHandshakeComplete() {
if peer.timersActive() { if peer.timersActive() {
peer.timers.retransmitHandshake.Del() peer.timers.retransmitHandshake.Del()
} }
peer.timers.handshakeAttempts.Store(0) atomic.StoreUint32(&peer.timers.handshakeAttempts, 0)
peer.timers.sentLastMinuteHandshake.Store(false) peer.timers.sentLastMinuteHandshake.Set(false)
peer.lastHandshakeNano.Store(time.Now().UnixNano()) atomic.StoreInt64(&peer.stats.lastHandshakeNano, time.Now().UnixNano())
} }
/* Should be called after an ephemeral key is created, which is before sending a handshake response or after receiving a handshake response. */ /* Should be called after an ephemeral key is created, which is before sending a handshake response or after receiving a handshake response. */
@ -192,9 +202,8 @@ func (peer *Peer) timersSessionDerived() {
/* Should be called before a packet with authentication -- keepalive, data, or handshake -- is sent, or after one is received. */ /* Should be called before a packet with authentication -- keepalive, data, or handshake -- is sent, or after one is received. */
func (peer *Peer) timersAnyAuthenticatedPacketTraversal() { func (peer *Peer) timersAnyAuthenticatedPacketTraversal() {
keepalive := peer.persistentKeepaliveInterval.Load() if peer.persistentKeepaliveInterval > 0 && peer.timersActive() {
if keepalive > 0 && peer.timersActive() { peer.timers.persistentKeepalive.Mod(time.Duration(peer.persistentKeepaliveInterval) * time.Second)
peer.timers.persistentKeepalive.Mod(time.Duration(keepalive) * time.Second)
} }
} }
@ -204,12 +213,9 @@ func (peer *Peer) timersInit() {
peer.timers.newHandshake = peer.NewTimer(expiredNewHandshake) peer.timers.newHandshake = peer.NewTimer(expiredNewHandshake)
peer.timers.zeroKeyMaterial = peer.NewTimer(expiredZeroKeyMaterial) peer.timers.zeroKeyMaterial = peer.NewTimer(expiredZeroKeyMaterial)
peer.timers.persistentKeepalive = peer.NewTimer(expiredPersistentKeepalive) peer.timers.persistentKeepalive = peer.NewTimer(expiredPersistentKeepalive)
} atomic.StoreUint32(&peer.timers.handshakeAttempts, 0)
peer.timers.sentLastMinuteHandshake.Set(false)
func (peer *Peer) timersStart() { peer.timers.needAnotherKeepalive.Set(false)
peer.timers.handshakeAttempts.Store(0)
peer.timers.sentLastMinuteHandshake.Store(false)
peer.timers.needAnotherKeepalive.Store(false)
} }
func (peer *Peer) timersStop() { func (peer *Peer) timersStop() {
@ -1,53 +1,56 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/ */
package device package device
import ( import (
"fmt" "sync/atomic"
"github.com/amnezia-vpn/amneziawg-go/tun" "golang.zx2c4.com/wireguard/tun"
) )
const DefaultMTU = 1420 const DefaultMTU = 1420
func (device *Device) RoutineTUNEventReader() { func (device *Device) RoutineTUNEventReader() {
device.log.Verbosef("Routine: event worker - started") setUp := false
logDebug := device.log.Debug
logInfo := device.log.Info
logError := device.log.Error
logDebug.Println("Routine: event worker - started")
device.state.starting.Done()
for event := range device.tun.device.Events() { for event := range device.tun.device.Events() {
if event&tun.EventMTUUpdate != 0 { if event&tun.EventMTUUpdate != 0 {
mtu, err := device.tun.device.MTU() mtu, err := device.tun.device.MTU()
old := atomic.LoadInt32(&device.tun.mtu)
if err != nil { if err != nil {
device.log.Errorf("Failed to load updated MTU of device: %v", err) logError.Println("Failed to load updated MTU of device:", err)
continue } else if int(old) != mtu {
} if mtu+MessageTransportSize > MaxMessageSize {
if mtu < 0 { logInfo.Println("MTU updated:", mtu, "(too large)")
device.log.Errorf("MTU not updated to negative value: %v", mtu) } else {
continue logInfo.Println("MTU updated:", mtu)
} }
var tooLarge string atomic.StoreInt32(&device.tun.mtu, int32(mtu))
if mtu > MaxContentSize {
tooLarge = fmt.Sprintf(" (too large, capped at %v)", MaxContentSize)
mtu = MaxContentSize
}
old := device.tun.mtu.Swap(int32(mtu))
if int(old) != mtu {
device.log.Verbosef("MTU updated: %v%s", mtu, tooLarge)
} }
} }
if event&tun.EventUp != 0 { if event&tun.EventUp != 0 && !setUp {
device.log.Verbosef("Interface up requested") logInfo.Println("Interface set up")
setUp = true
device.Up() device.Up()
} }
if event&tun.EventDown != 0 { if event&tun.EventDown != 0 && setUp {
device.log.Verbosef("Interface down requested") logInfo.Println("Interface set down")
setUp = false
device.Down() device.Down()
} }
} }
device.log.Verbosef("Routine: event worker - stopped") logDebug.Println("Routine: event worker - stopped")
device.state.stopping.Done()
} }
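Events arrive as a bit mask, so a single value from the channel may signal an MTU change and an up/down transition at the same time, which is why each flag is tested independently. A self-contained sketch of that dispatch with hypothetical types:

// tunEvent mirrors the flag-style events delivered by the TUN device.
type tunEvent int

const (
    eventUp tunEvent = 1 << iota
    eventDown
    eventMTUUpdate
)

// handleEvents tests every flag on each message, as the loop above does.
func handleEvents(events <-chan tunEvent, onUp, onDown, onMTUChange func()) {
    for e := range events {
        if e&eventMTUUpdate != 0 {
            onMTUChange() // re-read the MTU and clamp it to what a transport message can carry
        }
        if e&eventUp != 0 {
            onUp()
        }
        if e&eventDown != 0 {
            onDown()
        }
    }
}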
device/tun_test.go Normal file
@ -0,0 +1,56 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"errors"
"os"
"golang.zx2c4.com/wireguard/tun"
)
// newDummyTUN creates a dummy TUN device with the specified name.
func newDummyTUN(name string) tun.Device {
return &dummyTUN{
name: name,
packets: make(chan []byte, 100),
events: make(chan tun.Event, 10),
}
}
// A dummyTUN is a tun.Device which is used in unit tests.
type dummyTUN struct {
name string
mtu int
packets chan []byte
events chan tun.Event
}
func (d *dummyTUN) Events() chan tun.Event { return d.events }
func (*dummyTUN) File() *os.File { return nil }
func (*dummyTUN) Flush() error { return nil }
func (d *dummyTUN) MTU() (int, error) { return d.mtu, nil }
func (d *dummyTUN) Name() (string, error) { return d.name, nil }
func (d *dummyTUN) Close() error {
close(d.events)
close(d.packets)
return nil
}
func (d *dummyTUN) Read(b []byte, offset int) (int, error) {
buf, ok := <-d.packets
if !ok {
return 0, errors.New("device closed")
}
copy(b[offset:], buf)
return len(buf), nil
}
func (d *dummyTUN) Write(b []byte, offset int) (int, error) {
d.packets <- b[offset:]
return len(b), nil
}
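Because the dummy device loops packets and events through in-memory channels, a test in package device can drive the TUN-facing routines without a real interface. A hypothetical extra test, not part of this change set, assuming the 2019 import path and the fields above:

package device

import (
    "testing"

    "golang.zx2c4.com/wireguard/tun"
)

// TestDummyTUNInjection is an illustrative example only.
func TestDummyTUNInjection(t *testing.T) {
    d := newDummyTUN("dummy0").(*dummyTUN)
    d.mtu = DefaultMTU

    // What RoutineTUNEventReader consumes: the interface came up and the MTU changed.
    d.events <- tun.EventUp | tun.EventMTUUpdate

    // Loop one packet through the in-memory channel.
    if _, err := d.Write([]byte{0x45, 0x00}, 0); err != nil {
        t.Fatal(err)
    }
    buf := make([]byte, DefaultMTU)
    if n, err := d.Read(buf, 0); err != nil || n != 2 {
        t.Fatalf("unexpected read: n=%d err=%v", n, err)
    }
}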
@ -1,77 +1,43 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/ */
package device package device
import ( import (
"bufio" "bufio"
"bytes"
"errors"
"fmt" "fmt"
"io" "io"
"net" "net"
"net/netip"
"strconv" "strconv"
"strings" "strings"
"sync" "sync/atomic"
"time" "time"
"github.com/amnezia-vpn/amneziawg-go/ipc" "golang.zx2c4.com/wireguard/ipc"
) )
type IPCError struct { type IPCError struct {
code int64 // error code int64
err error // underlying/wrapped error
} }
func (s IPCError) Error() string { func (s IPCError) Error() string {
return fmt.Sprintf("IPC error %d: %v", s.code, s.err) return fmt.Sprintf("IPC error: %d", s.int64)
}
func (s IPCError) Unwrap() error {
return s.err
} }
func (s IPCError) ErrorCode() int64 { func (s IPCError) ErrorCode() int64 {
return s.code return s.int64
} }
func ipcErrorf(code int64, msg string, args ...any) *IPCError { func (device *Device) IpcGetOperation(socket *bufio.Writer) *IPCError {
return &IPCError{code: code, err: fmt.Errorf(msg, args...)} lines := make([]string, 0, 100)
} send := func(line string) {
lines = append(lines, line)
var byteBufferPool = &sync.Pool{
New: func() any { return new(bytes.Buffer) },
}
// IpcGetOperation implements the WireGuard configuration protocol "get" operation.
// See https://www.wireguard.com/xplatform/#configuration-protocol for details.
func (device *Device) IpcGetOperation(w io.Writer) error {
device.ipcMutex.RLock()
defer device.ipcMutex.RUnlock()
buf := byteBufferPool.Get().(*bytes.Buffer)
buf.Reset()
defer byteBufferPool.Put(buf)
sendf := func(format string, args ...any) {
fmt.Fprintf(buf, format, args...)
buf.WriteByte('\n')
}
keyf := func(prefix string, key *[32]byte) {
buf.Grow(len(key)*2 + 2 + len(prefix))
buf.WriteString(prefix)
buf.WriteByte('=')
const hex = "0123456789abcdef"
for i := 0; i < len(key); i++ {
buf.WriteByte(hex[key[i]>>4])
buf.WriteByte(hex[key[i]&0xf])
}
buf.WriteByte('\n')
} }
func() { func() {
// lock required resources // lock required resources
device.net.RLock() device.net.RLock()
@ -86,452 +52,337 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
// serialize device related values // serialize device related values
if !device.staticIdentity.privateKey.IsZero() { if !device.staticIdentity.privateKey.IsZero() {
keyf("private_key", (*[32]byte)(&device.staticIdentity.privateKey)) send("private_key=" + device.staticIdentity.privateKey.ToHex())
} }
if device.net.port != 0 { if device.net.port != 0 {
sendf("listen_port=%d", device.net.port) send(fmt.Sprintf("listen_port=%d", device.net.port))
} }
if device.net.fwmark != 0 { if device.net.fwmark != 0 {
sendf("fwmark=%d", device.net.fwmark) send(fmt.Sprintf("fwmark=%d", device.net.fwmark))
} }
if device.isAdvancedSecurityOn() { // serialize each peer state
if device.aSecCfg.junkPacketCount != 0 {
sendf("jc=%d", device.aSecCfg.junkPacketCount)
}
if device.aSecCfg.junkPacketMinSize != 0 {
sendf("jmin=%d", device.aSecCfg.junkPacketMinSize)
}
if device.aSecCfg.junkPacketMaxSize != 0 {
sendf("jmax=%d", device.aSecCfg.junkPacketMaxSize)
}
if device.aSecCfg.initPacketJunkSize != 0 {
sendf("s1=%d", device.aSecCfg.initPacketJunkSize)
}
if device.aSecCfg.responsePacketJunkSize != 0 {
sendf("s2=%d", device.aSecCfg.responsePacketJunkSize)
}
if device.aSecCfg.initPacketMagicHeader != 0 {
sendf("h1=%d", device.aSecCfg.initPacketMagicHeader)
}
if device.aSecCfg.responsePacketMagicHeader != 0 {
sendf("h2=%d", device.aSecCfg.responsePacketMagicHeader)
}
if device.aSecCfg.underloadPacketMagicHeader != 0 {
sendf("h3=%d", device.aSecCfg.underloadPacketMagicHeader)
}
if device.aSecCfg.transportPacketMagicHeader != 0 {
sendf("h4=%d", device.aSecCfg.transportPacketMagicHeader)
}
}
for _, peer := range device.peers.keyMap { for _, peer := range device.peers.keyMap {
// Serialize peer state. peer.RLock()
peer.handshake.mutex.RLock() defer peer.RUnlock()
keyf("public_key", (*[32]byte)(&peer.handshake.remoteStatic))
keyf("preshared_key", (*[32]byte)(&peer.handshake.presharedKey))
peer.handshake.mutex.RUnlock()
sendf("protocol_version=1")
peer.endpoint.Lock()
if peer.endpoint.val != nil {
sendf("endpoint=%s", peer.endpoint.val.DstToString())
}
peer.endpoint.Unlock()
nano := peer.lastHandshakeNano.Load() send("public_key=" + peer.handshake.remoteStatic.ToHex())
send("preshared_key=" + peer.handshake.presharedKey.ToHex())
send("protocol_version=1")
if peer.endpoint != nil {
send("endpoint=" + peer.endpoint.DstToString())
}
nano := atomic.LoadInt64(&peer.stats.lastHandshakeNano)
secs := nano / time.Second.Nanoseconds() secs := nano / time.Second.Nanoseconds()
nano %= time.Second.Nanoseconds() nano %= time.Second.Nanoseconds()
sendf("last_handshake_time_sec=%d", secs) send(fmt.Sprintf("last_handshake_time_sec=%d", secs))
sendf("last_handshake_time_nsec=%d", nano) send(fmt.Sprintf("last_handshake_time_nsec=%d", nano))
sendf("tx_bytes=%d", peer.txBytes.Load()) send(fmt.Sprintf("tx_bytes=%d", atomic.LoadUint64(&peer.stats.txBytes)))
sendf("rx_bytes=%d", peer.rxBytes.Load()) send(fmt.Sprintf("rx_bytes=%d", atomic.LoadUint64(&peer.stats.rxBytes)))
sendf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval.Load()) send(fmt.Sprintf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval))
for _, ip := range device.allowedips.EntriesForPeer(peer) {
send("allowed_ip=" + ip.String())
}
device.allowedips.EntriesForPeer(peer, func(prefix netip.Prefix) bool {
sendf("allowed_ip=%s", prefix.String())
return true
})
} }
}() }()
// send lines (does not require resource locks) // send lines (does not require resource locks)
if _, err := w.Write(buf.Bytes()); err != nil {
return ipcErrorf(ipc.IpcErrorIO, "failed to write output: %w", err) for _, line := range lines {
_, err := socket.WriteString(line + "\n")
if err != nil {
return &IPCError{ipc.IpcErrorIO}
}
} }
return nil return nil
} }
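The get operation streams the cross-platform key=value format described at wireguard.com/xplatform, and on the master side the AmneziaWG obfuscation parameters are appended whenever advanced security is configured. A hedged illustration of one dump, with every value invented for the example and the bracketed entries standing for hex-encoded 32-byte keys:

private_key=<64 hex characters>
listen_port=51820
jc=4
jmin=40
jmax=70
s1=15
s2=68
h1=1234567891
h2=1234567892
h3=1234567893
h4=1234567894
public_key=<64 hex characters>
preshared_key=<64 hex characters>
protocol_version=1
endpoint=203.0.113.5:51820
last_handshake_time_sec=0
last_handshake_time_nsec=0
tx_bytes=0
rx_bytes=0
persistent_keepalive_interval=25
allowed_ip=10.8.1.2/32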
// IpcSetOperation implements the WireGuard configuration protocol "set" operation. func (device *Device) IpcSetOperation(socket *bufio.Reader) *IPCError {
// See https://www.wireguard.com/xplatform/#configuration-protocol for details. scanner := bufio.NewScanner(socket)
func (device *Device) IpcSetOperation(r io.Reader) (err error) { logError := device.log.Error
device.ipcMutex.Lock() logDebug := device.log.Debug
defer device.ipcMutex.Unlock()
defer func() { var peer *Peer
if err != nil {
device.log.Errorf("%v", err)
}
}()
peer := new(ipcSetPeer) dummy := false
deviceConfig := true deviceConfig := true
tempASecCfg := aSecCfgType{}
scanner := bufio.NewScanner(r)
for scanner.Scan() { for scanner.Scan() {
// parse line
line := scanner.Text() line := scanner.Text()
if line == "" { if line == "" {
// Blank line means terminate operation.
err := device.handlePostConfig(&tempASecCfg)
if err != nil {
return err
}
peer.handlePostConfig()
return nil return nil
} }
key, value, ok := strings.Cut(line, "=") parts := strings.Split(line, "=")
if !ok { if len(parts) != 2 {
return ipcErrorf( return &IPCError{ipc.IpcErrorProtocol}
ipc.IpcErrorProtocol,
"failed to parse line %q",
line,
)
} }
key := parts[0]
value := parts[1]
if key == "public_key" { /* device configuration */
if deviceConfig {
deviceConfig = false
}
peer.handlePostConfig()
// Load/create the peer we are now configuring.
err := device.handlePublicKeyLine(peer, value)
if err != nil {
return err
}
continue
}
var err error
if deviceConfig { if deviceConfig {
err = device.handleDeviceLine(key, value, &tempASecCfg)
} else {
err = device.handlePeerLine(peer, key, value)
}
if err != nil {
return err
}
}
err = device.handlePostConfig(&tempASecCfg)
if err != nil {
return err
}
peer.handlePostConfig()
if err := scanner.Err(); err != nil { switch key {
return ipcErrorf(ipc.IpcErrorIO, "failed to read input: %w", err) case "private_key":
} var sk NoisePrivateKey
return nil err := sk.FromHex(value)
} if err != nil {
logError.Println("Failed to set private_key:", err)
return &IPCError{ipc.IpcErrorInvalid}
}
logDebug.Println("UAPI: Updating private key")
device.SetPrivateKey(sk)
func (device *Device) handleDeviceLine(key, value string, tempASecCfg *aSecCfgType) error { case "listen_port":
switch key {
case "private_key":
var sk NoisePrivateKey
err := sk.FromMaybeZeroHex(value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set private_key: %w", err)
}
device.log.Verbosef("UAPI: Updating private key")
device.SetPrivateKey(sk)
case "listen_port": // parse port number
port, err := strconv.ParseUint(value, 10, 16)
if err != nil { port, err := strconv.ParseUint(value, 10, 16)
return ipcErrorf(ipc.IpcErrorInvalid, "failed to parse listen_port: %w", err) if err != nil {
logError.Println("Failed to parse listen_port:", err)
return &IPCError{ipc.IpcErrorInvalid}
}
// update port and rebind
logDebug.Println("UAPI: Updating listen port")
device.net.Lock()
device.net.port = uint16(port)
device.net.Unlock()
if err := device.BindUpdate(); err != nil {
logError.Println("Failed to set listen_port:", err)
return &IPCError{ipc.IpcErrorPortInUse}
}
case "fwmark":
// parse fwmark field
fwmark, err := func() (uint32, error) {
if value == "" {
return 0, nil
}
mark, err := strconv.ParseUint(value, 10, 32)
return uint32(mark), err
}()
if err != nil {
logError.Println("Invalid fwmark", err)
return &IPCError{ipc.IpcErrorInvalid}
}
logDebug.Println("UAPI: Updating fwmark")
if err := device.BindSetMark(uint32(fwmark)); err != nil {
logError.Println("Failed to update fwmark:", err)
return &IPCError{ipc.IpcErrorPortInUse}
}
case "public_key":
// switch to peer configuration
logDebug.Println("UAPI: Transition to peer configuration")
deviceConfig = false
case "replace_peers":
if value != "true" {
logError.Println("Failed to set replace_peers, invalid value:", value)
return &IPCError{ipc.IpcErrorInvalid}
}
logDebug.Println("UAPI: Removing all peers")
device.RemoveAllPeers()
default:
logError.Println("Invalid UAPI device key:", key)
return &IPCError{ipc.IpcErrorInvalid}
}
} }
// update port and rebind /* peer configuration */
device.log.Verbosef("UAPI: Updating listen port")
device.net.Lock() if !deviceConfig {
device.net.port = uint16(port)
device.net.Unlock()
if err := device.BindUpdate(); err != nil { switch key {
return ipcErrorf(ipc.IpcErrorPortInUse, "failed to set listen_port: %w", err)
case "public_key":
var publicKey NoisePublicKey
err := publicKey.FromHex(value)
if err != nil {
logError.Println("Failed to get peer by public key:", err)
return &IPCError{ipc.IpcErrorInvalid}
}
// ignore peer with public key of device
device.staticIdentity.RLock()
dummy = device.staticIdentity.publicKey.Equals(publicKey)
device.staticIdentity.RUnlock()
if dummy {
peer = &Peer{}
} else {
peer = device.LookupPeer(publicKey)
}
if peer == nil {
peer, err = device.NewPeer(publicKey)
if err != nil {
logError.Println("Failed to create new peer:", err)
return &IPCError{ipc.IpcErrorInvalid}
}
if peer == nil {
dummy = true
peer = &Peer{}
} else {
logDebug.Println(peer, "- UAPI: Created")
}
}
case "remove":
// remove currently selected peer from device
if value != "true" {
logError.Println("Failed to set remove, invalid value:", value)
return &IPCError{ipc.IpcErrorInvalid}
}
if !dummy {
logDebug.Println(peer, "- UAPI: Removing")
device.RemovePeer(peer.handshake.remoteStatic)
}
peer = &Peer{}
dummy = true
case "preshared_key":
// update PSK
logDebug.Println(peer, "- UAPI: Updating preshared key")
peer.handshake.mutex.Lock()
err := peer.handshake.presharedKey.FromHex(value)
peer.handshake.mutex.Unlock()
if err != nil {
logError.Println("Failed to set preshared key:", err)
return &IPCError{ipc.IpcErrorInvalid}
}
case "endpoint":
// set endpoint destination
logDebug.Println(peer, "- UAPI: Updating endpoint")
err := func() error {
peer.Lock()
defer peer.Unlock()
endpoint, err := CreateEndpoint(value)
if err != nil {
return err
}
peer.endpoint = endpoint
return nil
}()
if err != nil {
logError.Println("Failed to set endpoint:", err, ":", value)
return &IPCError{ipc.IpcErrorInvalid}
}
case "persistent_keepalive_interval":
// update persistent keepalive interval
logDebug.Println(peer, "- UAPI: Updating persistent keepalive interval")
secs, err := strconv.ParseUint(value, 10, 16)
if err != nil {
logError.Println("Failed to set persistent keepalive interval:", err)
return &IPCError{ipc.IpcErrorInvalid}
}
old := peer.persistentKeepaliveInterval
peer.persistentKeepaliveInterval = uint16(secs)
// send immediate keepalive if we're turning it on and before it wasn't on
if old == 0 && secs != 0 {
if err != nil {
logError.Println("Failed to get tun device status:", err)
return &IPCError{ipc.IpcErrorIO}
}
if device.isUp.Get() && !dummy {
peer.SendKeepalive()
}
}
case "replace_allowed_ips":
logDebug.Println(peer, "- UAPI: Removing all allowedips")
if value != "true" {
logError.Println("Failed to replace allowedips, invalid value:", value)
return &IPCError{ipc.IpcErrorInvalid}
}
if dummy {
continue
}
device.allowedips.RemoveByPeer(peer)
case "allowed_ip":
logDebug.Println(peer, "- UAPI: Adding allowedip")
_, network, err := net.ParseCIDR(value)
if err != nil {
logError.Println("Failed to set allowed ip:", err)
return &IPCError{ipc.IpcErrorInvalid}
}
if dummy {
continue
}
ones, _ := network.Mask.Size()
device.allowedips.Insert(network.IP, uint(ones), peer)
case "protocol_version":
if value != "1" {
logError.Println("Invalid protocol version:", value)
return &IPCError{ipc.IpcErrorInvalid}
}
default:
logError.Println("Invalid UAPI peer key:", key)
return &IPCError{ipc.IpcErrorInvalid}
}
} }
case "fwmark":
mark, err := strconv.ParseUint(value, 10, 32)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "invalid fwmark: %w", err)
}
device.log.Verbosef("UAPI: Updating fwmark")
if err := device.BindSetMark(uint32(mark)); err != nil {
return ipcErrorf(ipc.IpcErrorPortInUse, "failed to update fwmark: %w", err)
}
case "replace_peers":
if value != "true" {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set replace_peers, invalid value: %v", value)
}
device.log.Verbosef("UAPI: Removing all peers")
device.RemoveAllPeers()
case "jc":
junkPacketCount, err := strconv.Atoi(value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse junk_packet_count %w", err)
}
device.log.Verbosef("UAPI: Updating junk_packet_count")
tempASecCfg.junkPacketCount = junkPacketCount
tempASecCfg.isSet = true
case "jmin":
junkPacketMinSize, err := strconv.Atoi(value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse junk_packet_min_size %w", err)
}
device.log.Verbosef("UAPI: Updating junk_packet_min_size")
tempASecCfg.junkPacketMinSize = junkPacketMinSize
tempASecCfg.isSet = true
case "jmax":
junkPacketMaxSize, err := strconv.Atoi(value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse junk_packet_max_size %w", err)
}
device.log.Verbosef("UAPI: Updating junk_packet_max_size")
tempASecCfg.junkPacketMaxSize = junkPacketMaxSize
tempASecCfg.isSet = true
case "s1":
initPacketJunkSize, err := strconv.Atoi(value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse init_packet_junk_size %w", err)
}
device.log.Verbosef("UAPI: Updating init_packet_junk_size")
tempASecCfg.initPacketJunkSize = initPacketJunkSize
tempASecCfg.isSet = true
case "s2":
responsePacketJunkSize, err := strconv.Atoi(value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse response_packet_junk_size %w", err)
}
device.log.Verbosef("UAPI: Updating response_packet_junk_size")
tempASecCfg.responsePacketJunkSize = responsePacketJunkSize
tempASecCfg.isSet = true
case "h1":
initPacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse init_packet_magic_header %w", err)
}
tempASecCfg.initPacketMagicHeader = uint32(initPacketMagicHeader)
tempASecCfg.isSet = true
case "h2":
responsePacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse response_packet_magic_header %w", err)
}
tempASecCfg.responsePacketMagicHeader = uint32(responsePacketMagicHeader)
tempASecCfg.isSet = true
case "h3":
underloadPacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse underload_packet_magic_header %w", err)
}
tempASecCfg.underloadPacketMagicHeader = uint32(underloadPacketMagicHeader)
tempASecCfg.isSet = true
case "h4":
transportPacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse transport_packet_magic_header %w", err)
}
tempASecCfg.transportPacketMagicHeader = uint32(transportPacketMagicHeader)
tempASecCfg.isSet = true
default:
return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI device key: %v", key)
} }
return nil return nil
} }
// An ipcSetPeer is the current state of an IPC set operation on a peer.
type ipcSetPeer struct {
*Peer // Peer is the current peer being operated on
dummy bool // dummy reports whether this peer is a temporary, placeholder peer
created bool // new reports whether this is a newly created peer
pkaOn bool // pkaOn reports whether the peer had the persistent keepalive turn on
}
func (peer *ipcSetPeer) handlePostConfig() {
if peer.Peer == nil || peer.dummy {
return
}
if peer.created {
peer.endpoint.disableRoaming = peer.device.net.brokenRoaming && peer.endpoint.val != nil
}
if peer.device.isUp() {
peer.Start()
if peer.pkaOn {
peer.SendKeepalive()
}
peer.SendStagedPackets()
}
}
func (device *Device) handlePublicKeyLine(
peer *ipcSetPeer,
value string,
) error {
// Load/create the peer we are configuring.
var publicKey NoisePublicKey
err := publicKey.FromHex(value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to get peer by public key: %w", err)
}
// Ignore peer with the same public key as this device.
device.staticIdentity.RLock()
peer.dummy = device.staticIdentity.publicKey.Equals(publicKey)
device.staticIdentity.RUnlock()
if peer.dummy {
peer.Peer = &Peer{}
} else {
peer.Peer = device.LookupPeer(publicKey)
}
peer.created = peer.Peer == nil
if peer.created {
peer.Peer, err = device.NewPeer(publicKey)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to create new peer: %w", err)
}
device.log.Verbosef("%v - UAPI: Created", peer.Peer)
}
return nil
}
func (device *Device) handlePeerLine(
peer *ipcSetPeer,
key, value string,
) error {
switch key {
case "update_only":
// allow disabling of creation
if value != "true" {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set update only, invalid value: %v", value)
}
if peer.created && !peer.dummy {
device.RemovePeer(peer.handshake.remoteStatic)
peer.Peer = &Peer{}
peer.dummy = true
}
case "remove":
// remove currently selected peer from device
if value != "true" {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set remove, invalid value: %v", value)
}
if !peer.dummy {
device.log.Verbosef("%v - UAPI: Removing", peer.Peer)
device.RemovePeer(peer.handshake.remoteStatic)
}
peer.Peer = &Peer{}
peer.dummy = true
case "preshared_key":
device.log.Verbosef("%v - UAPI: Updating preshared key", peer.Peer)
peer.handshake.mutex.Lock()
err := peer.handshake.presharedKey.FromHex(value)
peer.handshake.mutex.Unlock()
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set preshared key: %w", err)
}
case "endpoint":
device.log.Verbosef("%v - UAPI: Updating endpoint", peer.Peer)
endpoint, err := device.net.bind.ParseEndpoint(value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set endpoint %v: %w", value, err)
}
peer.endpoint.Lock()
defer peer.endpoint.Unlock()
peer.endpoint.val = endpoint
case "persistent_keepalive_interval":
device.log.Verbosef("%v - UAPI: Updating persistent keepalive interval", peer.Peer)
secs, err := strconv.ParseUint(value, 10, 16)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set persistent keepalive interval: %w", err)
}
old := peer.persistentKeepaliveInterval.Swap(uint32(secs))
// Send immediate keepalive if we're turning it on and before it wasn't on.
peer.pkaOn = old == 0 && secs != 0
case "replace_allowed_ips":
device.log.Verbosef("%v - UAPI: Removing all allowedips", peer.Peer)
if value != "true" {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to replace allowedips, invalid value: %v", value)
}
if peer.dummy {
return nil
}
device.allowedips.RemoveByPeer(peer.Peer)
case "allowed_ip":
device.log.Verbosef("%v - UAPI: Adding allowedip", peer.Peer)
prefix, err := netip.ParsePrefix(value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set allowed ip: %w", err)
}
if peer.dummy {
return nil
}
device.allowedips.Insert(prefix, peer.Peer)
case "protocol_version":
if value != "1" {
return ipcErrorf(ipc.IpcErrorInvalid, "invalid protocol version: %v", value)
}
default:
return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI peer key: %v", key)
}
return nil
}
func (device *Device) IpcGet() (string, error) {
buf := new(strings.Builder)
if err := device.IpcGetOperation(buf); err != nil {
return "", err
}
return buf.String(), nil
}
func (device *Device) IpcSet(uapiConf string) error {
return device.IpcSetOperation(strings.NewReader(uapiConf))
}
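IpcSet feeds the same parser from a plain string, which is how embedding applications configure the device programmatically; every failure comes back as an *IPCError carrying a UAPI errno. A hedged sketch with invented parameters, where dev, hexPrivateKey and hexPeerKey are placeholders supplied by the caller:

// configure is a hypothetical helper, not part of this change set.
func configure(dev *Device, hexPrivateKey, hexPeerKey string) error {
    conf := "private_key=" + hexPrivateKey + "\n" + // 64 hex characters
        "listen_port=51820\n" +
        "jc=4\njmin=40\njmax=70\n" + // AmneziaWG junk-packet parameters
        "s1=15\ns2=68\n" + // junk bytes prepended to handshake initiation/response
        "h1=1234567891\nh2=1234567892\nh3=1234567893\nh4=1234567894\n" + // custom magic headers
        "replace_peers=true\n" +
        "public_key=" + hexPeerKey + "\n" + // starts the peer section
        "allowed_ip=10.8.1.2/32\n" +
        "persistent_keepalive_interval=25\n"
    return dev.IpcSet(conf)
}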
func (device *Device) IpcHandle(socket net.Conn) { func (device *Device) IpcHandle(socket net.Conn) {
// create buffered read/writer
defer socket.Close() defer socket.Close()
buffered := func(s io.ReadWriter) *bufio.ReadWriter { buffered := func(s io.ReadWriter) *bufio.ReadWriter {
@ -540,44 +391,35 @@ func (device *Device) IpcHandle(socket net.Conn) {
return bufio.NewReadWriter(reader, writer) return bufio.NewReadWriter(reader, writer)
}(socket) }(socket)
for { defer buffered.Flush()
op, err := buffered.ReadString('\n')
if err != nil {
return
}
// handle operation op, err := buffered.ReadString('\n')
switch op { if err != nil {
case "set=1\n": return
err = device.IpcSetOperation(buffered.Reader) }
case "get=1\n":
var nextByte byte
nextByte, err = buffered.ReadByte()
if err != nil {
return
}
if nextByte != '\n' {
err = ipcErrorf(ipc.IpcErrorInvalid, "trailing character in UAPI get: %q", nextByte)
break
}
err = device.IpcGetOperation(buffered.Writer)
default:
device.log.Errorf("invalid UAPI operation: %v", op)
return
}
// write status // handle operation
var status *IPCError
if err != nil && !errors.As(err, &status) { var status *IPCError
// shouldn't happen
status = ipcErrorf(ipc.IpcErrorUnknown, "other UAPI error: %w", err) switch op {
} case "set=1\n":
if status != nil { status = device.IpcSetOperation(buffered.Reader)
device.log.Errorf("%v", status)
fmt.Fprintf(buffered, "errno=%d\n\n", status.ErrorCode()) case "get=1\n":
} else { status = device.IpcGetOperation(buffered.Writer)
fmt.Fprintf(buffered, "errno=0\n\n")
} default:
buffered.Flush() device.log.Error.Println("Invalid UAPI operation:", op)
return
}
// write status
if status != nil {
device.log.Error.Println(status)
fmt.Fprintf(buffered, "errno=%d\n\n", status.ErrorCode())
} else {
fmt.Fprintf(buffered, "errno=0\n\n")
} }
} }
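On the wire the exchange is line-oriented: the client writes an operation line ("get=1" or "set=1"), for set any number of key=value lines, then a blank line; the device answers with any output followed by an errno line and a blank line. A hypothetical client-side helper for the get round trip, assuming the UAPI connection has already been dialed (socket paths are platform specific; imports bufio, io, net and strings):

// uapiGet performs one "get" round trip over an already-connected UAPI socket.
func uapiGet(c net.Conn) (string, error) {
    if _, err := io.WriteString(c, "get=1\n\n"); err != nil {
        return "", err
    }
    r := bufio.NewReader(c)
    var sb strings.Builder
    for {
        line, err := r.ReadString('\n')
        if err != nil {
            return "", err
        }
        if line == "\n" { // blank line ends the response
            return sb.String(), nil
        }
        sb.WriteString(line) // key=value lines plus the final "errno=<code>" line
    }
}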
device/version.go Normal file
@ -0,0 +1,3 @@
package device
const WireGuardGoVersion = "0.0.20190908"
@ -1,51 +0,0 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package main
import (
"bytes"
"go/format"
"io/fs"
"os"
"path/filepath"
"runtime"
"sync"
"testing"
)
func TestFormatting(t *testing.T) {
var wg sync.WaitGroup
filepath.WalkDir(".", func(path string, d fs.DirEntry, err error) error {
if err != nil {
t.Errorf("unable to walk %s: %v", path, err)
return nil
}
if d.IsDir() || filepath.Ext(path) != ".go" {
return nil
}
wg.Add(1)
go func(path string) {
defer wg.Done()
src, err := os.ReadFile(path)
if err != nil {
t.Errorf("unable to read %s: %v", path, err)
return
}
if runtime.GOOS == "windows" {
src = bytes.ReplaceAll(src, []byte{'\r', '\n'}, []byte{'\n'})
}
formatted, err := format.Source(src)
if err != nil {
t.Errorf("unable to format %s: %v", path, err)
return
}
if !bytes.Equal(src, formatted) {
t.Errorf("unformatted code: %s", path)
}
}(path)
return nil
})
wg.Wait()
}
go.mod
@ -1,17 +1,10 @@
module github.com/amnezia-vpn/amneziawg-go module golang.zx2c4.com/wireguard
go 1.24 go 1.12
require ( require (
github.com/tevino/abool/v2 v2.1.0 golang.org/x/crypto v0.0.0-20190829043050-9756ffdc2472
golang.org/x/crypto v0.36.0 golang.org/x/net v0.0.0-20190827160401-ba9fcec4b297
golang.org/x/net v0.37.0 golang.org/x/sys v0.0.0-20190830023255-19e00faab6ad
golang.org/x/sys v0.31.0 golang.org/x/text v0.3.2
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
gvisor.dev/gvisor v0.0.0-20250130013005-04f9204697c6
)
require (
github.com/google/btree v1.1.3 // indirect
golang.org/x/time v0.9.0 // indirect
) )
go.sum
@ -1,20 +1,14 @@
github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
github.com/google/btree v1.1.3/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4= golang.org/x/crypto v0.0.0-20190829043050-9756ffdc2472 h1:Gv7RPwsi3eZ2Fgewe3CBsuOebPwO27PoXzRpJPsvSSM=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= golang.org/x/crypto v0.0.0-20190829043050-9756ffdc2472/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
github.com/tevino/abool/v2 v2.1.0 h1:7w+Vf9f/5gmKT4m4qkayb33/92M+Um45F2BkHOR+L/c= golang.org/x/net v0.0.0-20190827160401-ba9fcec4b297 h1:k7pJ2yAPLPgbskkFdhRCsA77k2fySZ1zf2zCjvQCiIM=
github.com/tevino/abool/v2 v2.1.0/go.mod h1:+Lmlqk6bHDWHqN1cbxqhwEAwMPXgc8I1SDEamtseuXY= golang.org/x/net v0.0.0-20190827160401-ba9fcec4b297/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/mod v0.21.0 h1:vvrHzRwRfVKSiLrG+d4FMl/Qi4ukBCE6kZlTUkDYRT0= golang.org/x/sys v0.0.0-20190830023255-19e00faab6ad h1:cCejgArrk10gX6kFqjWeLwXD7aVMqWoRpyUCaaJSggc=
golang.org/x/mod v0.21.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= golang.org/x/sys v0.0.0-20190830023255-19e00faab6ad/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/net v0.37.0 h1:1zLorHbz+LYj7MQlSf1+2tPIIgibq2eL5xkrGk6f+2c= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/net v0.37.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs=
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/time v0.9.0 h1:EsRrnYcQiGH+5FfbgvV4AP7qEZstoyrHB0DzarOQ4ZY=
golang.org/x/time v0.9.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
gvisor.dev/gvisor v0.0.0-20250130013005-04f9204697c6 h1:6B7MdW3OEbJqOMr7cEYU9bkzvCjUBX/JlXk12xcANuQ=
gvisor.dev/gvisor v0.0.0-20250130013005-04f9204697c6/go.mod h1:5DMfjtclAbTIjbXqO1qCe2K5GKKxWz2JHvCChuTcJEM=
@ -1,287 +0,0 @@
// Copyright 2021 The Go Authors. All rights reserved.
// Copyright 2015 Microsoft
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build windows
package namedpipe
import (
"io"
"os"
"runtime"
"sync"
"sync/atomic"
"time"
"unsafe"
"golang.org/x/sys/windows"
)
type timeoutChan chan struct{}
var (
ioInitOnce sync.Once
ioCompletionPort windows.Handle
)
// ioResult contains the result of an asynchronous IO operation
type ioResult struct {
bytes uint32
err error
}
// ioOperation represents an outstanding asynchronous Win32 IO
type ioOperation struct {
o windows.Overlapped
ch chan ioResult
}
func initIo() {
h, err := windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0)
if err != nil {
panic(err)
}
ioCompletionPort = h
go ioCompletionProcessor(h)
}
// file implements Reader, Writer, and Closer on a Win32 handle without blocking in a syscall.
// It takes ownership of this handle and will close it if it is garbage collected.
type file struct {
handle windows.Handle
wg sync.WaitGroup
wgLock sync.RWMutex
closing atomic.Bool
socket bool
readDeadline deadlineHandler
writeDeadline deadlineHandler
}
type deadlineHandler struct {
setLock sync.Mutex
channel timeoutChan
channelLock sync.RWMutex
timer *time.Timer
timedout atomic.Bool
}
// makeFile makes a new file from an existing file handle
func makeFile(h windows.Handle) (*file, error) {
f := &file{handle: h}
ioInitOnce.Do(initIo)
_, err := windows.CreateIoCompletionPort(h, ioCompletionPort, 0, 0)
if err != nil {
return nil, err
}
err = windows.SetFileCompletionNotificationModes(h, windows.FILE_SKIP_COMPLETION_PORT_ON_SUCCESS|windows.FILE_SKIP_SET_EVENT_ON_HANDLE)
if err != nil {
return nil, err
}
f.readDeadline.channel = make(timeoutChan)
f.writeDeadline.channel = make(timeoutChan)
return f, nil
}
// closeHandle closes the resources associated with a Win32 handle
func (f *file) closeHandle() {
f.wgLock.Lock()
// Atomically set that we are closing, releasing the resources only once.
if f.closing.Swap(true) == false {
f.wgLock.Unlock()
// cancel all IO and wait for it to complete
windows.CancelIoEx(f.handle, nil)
f.wg.Wait()
// at this point, no new IO can start
windows.Close(f.handle)
f.handle = 0
} else {
f.wgLock.Unlock()
}
}
// Close closes a file.
func (f *file) Close() error {
f.closeHandle()
return nil
}
// prepareIo prepares for a new IO operation.
// The caller must call f.wg.Done() when the IO is finished, prior to Close() returning.
func (f *file) prepareIo() (*ioOperation, error) {
f.wgLock.RLock()
if f.closing.Load() {
f.wgLock.RUnlock()
return nil, os.ErrClosed
}
f.wg.Add(1)
f.wgLock.RUnlock()
c := &ioOperation{}
c.ch = make(chan ioResult)
return c, nil
}
// ioCompletionProcessor processes completed async IOs forever
func ioCompletionProcessor(h windows.Handle) {
for {
var bytes uint32
var key uintptr
var op *ioOperation
err := windows.GetQueuedCompletionStatus(h, &bytes, &key, (**windows.Overlapped)(unsafe.Pointer(&op)), windows.INFINITE)
if op == nil {
panic(err)
}
op.ch <- ioResult{bytes, err}
}
}
// asyncIo processes the return value from ReadFile or WriteFile, blocking until
// the operation has actually completed.
func (f *file) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, err error) (int, error) {
if err != windows.ERROR_IO_PENDING {
return int(bytes), err
}
if f.closing.Load() {
windows.CancelIoEx(f.handle, &c.o)
}
var timeout timeoutChan
if d != nil {
d.channelLock.Lock()
timeout = d.channel
d.channelLock.Unlock()
}
var r ioResult
select {
case r = <-c.ch:
err = r.err
if err == windows.ERROR_OPERATION_ABORTED {
if f.closing.Load() {
err = os.ErrClosed
}
} else if err != nil && f.socket {
// err is from Win32. Query the overlapped structure to get the winsock error.
var bytes, flags uint32
err = windows.WSAGetOverlappedResult(f.handle, &c.o, &bytes, false, &flags)
}
case <-timeout:
windows.CancelIoEx(f.handle, &c.o)
r = <-c.ch
err = r.err
if err == windows.ERROR_OPERATION_ABORTED {
err = os.ErrDeadlineExceeded
}
}
// runtime.KeepAlive is needed, as c is passed via native
// code to ioCompletionProcessor, so c must remain alive
// until the channel read is complete.
runtime.KeepAlive(c)
return int(r.bytes), err
}
// Read reads from a file handle.
func (f *file) Read(b []byte) (int, error) {
c, err := f.prepareIo()
if err != nil {
return 0, err
}
defer f.wg.Done()
if f.readDeadline.timedout.Load() {
return 0, os.ErrDeadlineExceeded
}
var bytes uint32
err = windows.ReadFile(f.handle, b, &bytes, &c.o)
n, err := f.asyncIo(c, &f.readDeadline, bytes, err)
runtime.KeepAlive(b)
// Handle EOF conditions.
if err == nil && n == 0 && len(b) != 0 {
return 0, io.EOF
} else if err == windows.ERROR_BROKEN_PIPE {
return 0, io.EOF
} else {
return n, err
}
}
// Write writes to a file handle.
func (f *file) Write(b []byte) (int, error) {
c, err := f.prepareIo()
if err != nil {
return 0, err
}
defer f.wg.Done()
if f.writeDeadline.timedout.Load() {
return 0, os.ErrDeadlineExceeded
}
var bytes uint32
err = windows.WriteFile(f.handle, b, &bytes, &c.o)
n, err := f.asyncIo(c, &f.writeDeadline, bytes, err)
runtime.KeepAlive(b)
return n, err
}
func (f *file) SetReadDeadline(deadline time.Time) error {
return f.readDeadline.set(deadline)
}
func (f *file) SetWriteDeadline(deadline time.Time) error {
return f.writeDeadline.set(deadline)
}
func (f *file) Flush() error {
return windows.FlushFileBuffers(f.handle)
}
func (f *file) Fd() uintptr {
return uintptr(f.handle)
}
func (d *deadlineHandler) set(deadline time.Time) error {
d.setLock.Lock()
defer d.setLock.Unlock()
if d.timer != nil {
if !d.timer.Stop() {
<-d.channel
}
d.timer = nil
}
d.timedout.Store(false)
select {
case <-d.channel:
d.channelLock.Lock()
d.channel = make(chan struct{})
d.channelLock.Unlock()
default:
}
if deadline.IsZero() {
return nil
}
timeoutIO := func() {
d.timedout.Store(true)
close(d.channel)
}
now := time.Now()
duration := deadline.Sub(now)
if deadline.After(now) {
// Deadline is in the future, set a timer to wait
d.timer = time.AfterFunc(duration, timeoutIO)
} else {
// Deadline is in the past. Cancel all pending IO now.
timeoutIO()
}
return nil
}
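set() distinguishes three cases: a zero deadline disables the timeout, a future deadline arms a timer that closes the channel, and a past deadline cancels pending IO immediately. A hedged usage sketch against the net.Conn surface of this package (imports net, os and time; conn is a pipe connection obtained elsewhere):

// configureDeadlines is a hypothetical helper showing the three cases handled by set().
func configureDeadlines(conn net.Conn) {
    conn.SetReadDeadline(time.Now().Add(5 * time.Second)) // reads fail with os.ErrDeadlineExceeded after 5s
    conn.SetReadDeadline(time.Now().Add(-time.Second))    // already expired: pending reads are cancelled now
    conn.SetReadDeadline(time.Time{})                     // zero value: no read deadline at all
}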
@ -1,485 +0,0 @@
// Copyright 2021 The Go Authors. All rights reserved.
// Copyright 2015 Microsoft
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build windows
// Package namedpipe implements a net.Conn and net.Listener around Windows named pipes.
package namedpipe
import (
"context"
"io"
"net"
"os"
"runtime"
"sync/atomic"
"time"
"unsafe"
"golang.org/x/sys/windows"
)
type pipe struct {
*file
path string
}
type messageBytePipe struct {
pipe
writeClosed atomic.Bool
readEOF bool
}
type pipeAddress string
func (f *pipe) LocalAddr() net.Addr {
return pipeAddress(f.path)
}
func (f *pipe) RemoteAddr() net.Addr {
return pipeAddress(f.path)
}
func (f *pipe) SetDeadline(t time.Time) error {
f.SetReadDeadline(t)
f.SetWriteDeadline(t)
return nil
}
// CloseWrite closes the write side of a message pipe in byte mode.
func (f *messageBytePipe) CloseWrite() error {
if !f.writeClosed.CompareAndSwap(false, true) {
return io.ErrClosedPipe
}
err := f.file.Flush()
if err != nil {
f.writeClosed.Store(false)
return err
}
_, err = f.file.Write(nil)
if err != nil {
f.writeClosed.Store(false)
return err
}
return nil
}
// Write writes bytes to a message pipe in byte mode. Zero-byte writes are ignored, since
// they are used to implement CloseWrite.
func (f *messageBytePipe) Write(b []byte) (int, error) {
if f.writeClosed.Load() {
return 0, io.ErrClosedPipe
}
if len(b) == 0 {
return 0, nil
}
return f.file.Write(b)
}
// Read reads bytes from a message pipe in byte mode. A read of a zero-byte message on a message
// mode pipe will return io.EOF, as will all subsequent reads.
func (f *messageBytePipe) Read(b []byte) (int, error) {
if f.readEOF {
return 0, io.EOF
}
n, err := f.file.Read(b)
if err == io.EOF {
// If this was the result of a zero-byte read, then
// it is possible that the read was due to a zero-size
// message. Since we are simulating CloseWrite with a
// zero-byte message, ensure that all future Read calls
// also return EOF.
f.readEOF = true
} else if err == windows.ERROR_MORE_DATA {
// ERROR_MORE_DATA indicates that the pipe's read mode is message mode
// and the message still has more bytes. Treat this as a success, since
// this package presents all named pipes as byte streams.
err = nil
}
return n, err
}
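
Usage aside (not part of this file): the zero-byte-message EOF simulation above is what lets a client half-close a request and still read the full reply. A hedged sketch, assuming the server listens in message mode; the path argument and helper name are illustrative:

//go:build windows

package pipedemo // hypothetical example package, not part of this repository

import (
    "io"

    "github.com/amnezia-vpn/amneziawg-go/ipc/namedpipe"
)

// requestOnce writes one request, half-closes the write side so the server's
// reads return io.EOF, then drains the reply until our own EOF.
func requestOnce(path string, req []byte) ([]byte, error) {
    conn, err := namedpipe.DialTimeout(path, 0)
    if err != nil {
        return nil, err
    }
    defer conn.Close()
    if _, err := conn.Write(req); err != nil {
        return nil, err
    }
    // CloseWrite is only offered when the pipe is in message mode.
    if cw, ok := conn.(interface{ CloseWrite() error }); ok {
        if err := cw.CloseWrite(); err != nil {
            return nil, err
        }
    }
    return io.ReadAll(conn)
}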
func (f *pipe) Handle() windows.Handle {
return f.handle
}
func (s pipeAddress) Network() string {
return "pipe"
}
func (s pipeAddress) String() string {
return string(s)
}
// tryDialPipe attempts to dial the specified pipe until cancellation or timeout.
func tryDialPipe(ctx context.Context, path *string) (windows.Handle, error) {
for {
select {
case <-ctx.Done():
return 0, ctx.Err()
default:
path16, err := windows.UTF16PtrFromString(*path)
if err != nil {
return 0, err
}
h, err := windows.CreateFile(path16, windows.GENERIC_READ|windows.GENERIC_WRITE, 0, nil, windows.OPEN_EXISTING, windows.FILE_FLAG_OVERLAPPED|windows.SECURITY_SQOS_PRESENT|windows.SECURITY_ANONYMOUS, 0)
if err == nil {
return h, nil
}
if err != windows.ERROR_PIPE_BUSY {
return h, &os.PathError{Err: err, Op: "open", Path: *path}
}
// Wait 10 msec and try again. This is a rather simplistic
// view, as we always try each 10 milliseconds.
time.Sleep(10 * time.Millisecond)
}
}
}
// DialConfig exposes various options for use in Dial and DialContext.
type DialConfig struct {
ExpectedOwner *windows.SID // If non-nil, the pipe is verified to be owned by this SID.
}
// DialTimeout connects to the specified named pipe by path, timing out if the
// connection takes longer than the specified duration. If timeout is zero, then
// we use a default timeout of 2 seconds.
func (config *DialConfig) DialTimeout(path string, timeout time.Duration) (net.Conn, error) {
if timeout == 0 {
timeout = time.Second * 2
}
absTimeout := time.Now().Add(timeout)
ctx, _ := context.WithDeadline(context.Background(), absTimeout)
conn, err := config.DialContext(ctx, path)
if err == context.DeadlineExceeded {
return nil, os.ErrDeadlineExceeded
}
return conn, err
}
// DialContext attempts to connect to the specified named pipe by path.
func (config *DialConfig) DialContext(ctx context.Context, path string) (net.Conn, error) {
var err error
var h windows.Handle
h, err = tryDialPipe(ctx, &path)
if err != nil {
return nil, err
}
if config.ExpectedOwner != nil {
sd, err := windows.GetSecurityInfo(h, windows.SE_FILE_OBJECT, windows.OWNER_SECURITY_INFORMATION)
if err != nil {
windows.Close(h)
return nil, err
}
realOwner, _, err := sd.Owner()
if err != nil {
windows.Close(h)
return nil, err
}
if !realOwner.Equals(config.ExpectedOwner) {
windows.Close(h)
return nil, windows.ERROR_ACCESS_DENIED
}
}
var flags uint32
err = windows.GetNamedPipeInfo(h, &flags, nil, nil, nil)
if err != nil {
windows.Close(h)
return nil, err
}
f, err := makeFile(h)
if err != nil {
windows.Close(h)
return nil, err
}
// If the pipe is in message mode, return a message byte pipe, which
// supports CloseWrite.
if flags&windows.PIPE_TYPE_MESSAGE != 0 {
return &messageBytePipe{
pipe: pipe{file: f, path: path},
}, nil
}
return &pipe{file: f, path: path}, nil
}
var defaultDialer DialConfig
// DialTimeout calls DialConfig.DialTimeout using an empty configuration.
func DialTimeout(path string, timeout time.Duration) (net.Conn, error) {
return defaultDialer.DialTimeout(path, timeout)
}
// DialContext calls DialConfig.DialContext using an empty configuration.
func DialContext(ctx context.Context, path string) (net.Conn, error) {
return defaultDialer.DialContext(ctx, path)
}
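
Usage aside (not part of this file): a hedged sketch of dialing with the two-second default described above and the ExpectedOwner check, here refusing pipes not owned by LocalSystem; the pipe name and function name are illustrative:

//go:build windows

package pipedemo // hypothetical example package, not part of this repository

import (
    "net"

    "github.com/amnezia-vpn/amneziawg-go/ipc/namedpipe"
    "golang.org/x/sys/windows"
)

// dialTrustedPipe only accepts a pipe whose owner is the LocalSystem account.
func dialTrustedPipe() (net.Conn, error) {
    system, err := windows.CreateWellKnownSid(windows.WinLocalSystemSid)
    if err != nil {
        return nil, err
    }
    cfg := namedpipe.DialConfig{ExpectedOwner: system}
    // A zero timeout falls back to the two-second default.
    return cfg.DialTimeout(`\\.\pipe\example-pipe`, 0)
}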
type acceptResponse struct {
f *file
err error
}
type pipeListener struct {
firstHandle windows.Handle
path string
config ListenConfig
acceptCh chan chan acceptResponse
closeCh chan int
doneCh chan int
}
func makeServerPipeHandle(path string, sd *windows.SECURITY_DESCRIPTOR, c *ListenConfig, isFirstPipe bool) (windows.Handle, error) {
path16, err := windows.UTF16PtrFromString(path)
if err != nil {
return 0, &os.PathError{Op: "open", Path: path, Err: err}
}
var oa windows.OBJECT_ATTRIBUTES
oa.Length = uint32(unsafe.Sizeof(oa))
var ntPath windows.NTUnicodeString
if err := windows.RtlDosPathNameToNtPathName(path16, &ntPath, nil, nil); err != nil {
if ntstatus, ok := err.(windows.NTStatus); ok {
err = ntstatus.Errno()
}
return 0, &os.PathError{Op: "open", Path: path, Err: err}
}
defer windows.LocalFree(windows.Handle(unsafe.Pointer(ntPath.Buffer)))
oa.ObjectName = &ntPath
// The security descriptor is only needed for the first pipe.
if isFirstPipe {
if sd != nil {
oa.SecurityDescriptor = sd
} else {
// Construct the default named pipe security descriptor.
var acl *windows.ACL
if err := windows.RtlDefaultNpAcl(&acl); err != nil {
return 0, err
}
defer windows.LocalFree(windows.Handle(unsafe.Pointer(acl)))
sd, err = windows.NewSecurityDescriptor()
if err != nil {
return 0, err
}
if err = sd.SetDACL(acl, true, false); err != nil {
return 0, err
}
oa.SecurityDescriptor = sd
}
}
typ := uint32(windows.FILE_PIPE_REJECT_REMOTE_CLIENTS)
if c.MessageMode {
typ |= windows.FILE_PIPE_MESSAGE_TYPE
}
disposition := uint32(windows.FILE_OPEN)
access := uint32(windows.GENERIC_READ | windows.GENERIC_WRITE | windows.SYNCHRONIZE)
if isFirstPipe {
disposition = windows.FILE_CREATE
// By not asking for read or write access, the named pipe file system
// will put this pipe into an initially disconnected state, blocking
// client connections until the next call with isFirstPipe == false.
access = windows.SYNCHRONIZE
}
timeout := int64(-50 * 10000) // 50ms, expressed as a negative (relative) NT timeout in 100ns units
var (
h windows.Handle
iosb windows.IO_STATUS_BLOCK
)
err = windows.NtCreateNamedPipeFile(&h, access, &oa, &iosb, windows.FILE_SHARE_READ|windows.FILE_SHARE_WRITE, disposition, 0, typ, 0, 0, 0xffffffff, uint32(c.InputBufferSize), uint32(c.OutputBufferSize), &timeout)
if err != nil {
if ntstatus, ok := err.(windows.NTStatus); ok {
err = ntstatus.Errno()
}
return 0, &os.PathError{Op: "open", Path: path, Err: err}
}
runtime.KeepAlive(ntPath)
return h, nil
}
func (l *pipeListener) makeServerPipe() (*file, error) {
h, err := makeServerPipeHandle(l.path, nil, &l.config, false)
if err != nil {
return nil, err
}
f, err := makeFile(h)
if err != nil {
windows.Close(h)
return nil, err
}
return f, nil
}
func (l *pipeListener) makeConnectedServerPipe() (*file, error) {
p, err := l.makeServerPipe()
if err != nil {
return nil, err
}
// Wait for the client to connect.
ch := make(chan error)
go func(p *file) {
ch <- connectPipe(p)
}(p)
select {
case err = <-ch:
if err != nil {
p.Close()
p = nil
}
case <-l.closeCh:
// Abort the connect request by closing the handle.
p.Close()
p = nil
err = <-ch
if err == nil || err == os.ErrClosed {
err = net.ErrClosed
}
}
return p, err
}
func (l *pipeListener) listenerRoutine() {
closed := false
for !closed {
select {
case <-l.closeCh:
closed = true
case responseCh := <-l.acceptCh:
var (
p *file
err error
)
for {
p, err = l.makeConnectedServerPipe()
// If the connection was immediately closed by the client, try
// again.
if err != windows.ERROR_NO_DATA {
break
}
}
responseCh <- acceptResponse{p, err}
closed = err == net.ErrClosed
}
}
windows.Close(l.firstHandle)
l.firstHandle = 0
// Notify Close and Accept callers that the handle has been closed.
close(l.doneCh)
}
// ListenConfig contains configuration for the pipe listener.
type ListenConfig struct {
// SecurityDescriptor contains a Windows security descriptor. If nil, the default from RtlDefaultNpAcl is used.
SecurityDescriptor *windows.SECURITY_DESCRIPTOR
// MessageMode determines whether the pipe is in byte or message mode. In either
// case the pipe is read in byte mode by default. The only practical difference in
// this implementation is that CloseWrite is only supported for message mode pipes;
// CloseWrite is implemented as a zero-byte write, but zero-byte writes are only
// transferred to the reader (and returned as io.EOF in this implementation)
// when the pipe is in message mode.
MessageMode bool
// InputBufferSize specifies the initial size of the input buffer, in bytes, which the OS will grow as needed.
InputBufferSize int32
// OutputBufferSize specifies the initial size of the output buffer, in bytes, which the OS will grow as needed.
OutputBufferSize int32
}
// Listen creates a listener on a Windows named pipe path, such as \\.\pipe\mypipe.
// The pipe must not already exist.
func (c *ListenConfig) Listen(path string) (net.Listener, error) {
h, err := makeServerPipeHandle(path, c.SecurityDescriptor, c, true)
if err != nil {
return nil, err
}
l := &pipeListener{
firstHandle: h,
path: path,
config: *c,
acceptCh: make(chan chan acceptResponse),
closeCh: make(chan int),
doneCh: make(chan int),
}
// The first connection is swallowed on Windows 7 & 8, so synthesize it.
if maj, min, _ := windows.RtlGetNtVersionNumbers(); maj < 6 || (maj == 6 && min < 4) {
path16, err := windows.UTF16PtrFromString(path)
if err == nil {
h, err = windows.CreateFile(path16, 0, 0, nil, windows.OPEN_EXISTING, windows.SECURITY_SQOS_PRESENT|windows.SECURITY_ANONYMOUS, 0)
if err == nil {
windows.CloseHandle(h)
}
}
}
go l.listenerRoutine()
return l, nil
}
var defaultListener ListenConfig
// Listen calls ListenConfig.Listen using an empty configuration.
func Listen(path string) (net.Listener, error) {
return defaultListener.Listen(path)
}
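
Usage aside (not part of this file): a minimal echo server built on ListenConfig with MessageMode enabled, so the CloseWrite half-close documented above is available to both sides; the pipe name is illustrative and errors are only logged:

//go:build windows

package pipedemo // hypothetical example package, not part of this repository

import (
    "io"
    "log"

    "github.com/amnezia-vpn/amneziawg-go/ipc/namedpipe"
)

// echoServer accepts connections and echoes each one until the client half-closes.
func echoServer() error {
    l, err := (&namedpipe.ListenConfig{MessageMode: true}).Listen(`\\.\pipe\example-echo`)
    if err != nil {
        return err
    }
    defer l.Close()
    for {
        conn, err := l.Accept()
        if err != nil {
            return err // net.ErrClosed once the listener has been closed
        }
        go func() {
            defer conn.Close()
            // Echo until the client calls CloseWrite and our read returns io.EOF.
            if _, err := io.Copy(conn, conn); err != nil {
                log.Println("echo:", err)
            }
            // Signal end-of-reply to the client the same way.
            if cw, ok := conn.(interface{ CloseWrite() error }); ok {
                cw.CloseWrite()
            }
        }()
    }
}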
func connectPipe(p *file) error {
c, err := p.prepareIo()
if err != nil {
return err
}
defer p.wg.Done()
err = windows.ConnectNamedPipe(p.handle, &c.o)
_, err = p.asyncIo(c, nil, 0, err)
if err != nil && err != windows.ERROR_PIPE_CONNECTED {
return err
}
return nil
}
func (l *pipeListener) Accept() (net.Conn, error) {
ch := make(chan acceptResponse)
select {
case l.acceptCh <- ch:
response := <-ch
err := response.err
if err != nil {
return nil, err
}
if l.config.MessageMode {
return &messageBytePipe{
pipe: pipe{file: response.f, path: l.path},
}, nil
}
return &pipe{file: response.f, path: l.path}, nil
case <-l.doneCh:
return nil, net.ErrClosed
}
}
func (l *pipeListener) Close() error {
select {
case l.closeCh <- 1:
<-l.doneCh
case <-l.doneCh:
}
return nil
}
func (l *pipeListener) Addr() net.Addr {
return pipeAddress(l.path)
}

@ -1,674 +0,0 @@
// Copyright 2021 The Go Authors. All rights reserved.
// Copyright 2015 Microsoft
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build windows
package namedpipe_test
import (
"bufio"
"bytes"
"context"
"errors"
"io"
"net"
"os"
"sync"
"syscall"
"testing"
"time"
"github.com/amnezia-vpn/amneziawg-go/ipc/namedpipe"
"golang.org/x/sys/windows"
)
func randomPipePath() string {
guid, err := windows.GenerateGUID()
if err != nil {
panic(err)
}
return `\\.\PIPE\go-namedpipe-test-` + guid.String()
}
func TestPingPong(t *testing.T) {
const (
ping = 42
pong = 24
)
pipePath := randomPipePath()
listener, err := namedpipe.Listen(pipePath)
if err != nil {
t.Fatalf("unable to listen on pipe: %v", err)
}
defer listener.Close()
go func() {
incoming, err := listener.Accept()
if err != nil {
t.Fatalf("unable to accept pipe connection: %v", err)
}
defer incoming.Close()
var data [1]byte
_, err = incoming.Read(data[:])
if err != nil {
t.Fatalf("unable to read ping from pipe: %v", err)
}
if data[0] != ping {
t.Fatalf("expected ping, got %d", data[0])
}
data[0] = pong
_, err = incoming.Write(data[:])
if err != nil {
t.Fatalf("unable to write pong to pipe: %v", err)
}
}()
client, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
if err != nil {
t.Fatalf("unable to dial pipe: %v", err)
}
defer client.Close()
client.SetDeadline(time.Now().Add(time.Second * 5))
var data [1]byte
data[0] = ping
_, err = client.Write(data[:])
if err != nil {
t.Fatalf("unable to write ping to pipe: %v", err)
}
_, err = client.Read(data[:])
if err != nil {
t.Fatalf("unable to read pong from pipe: %v", err)
}
if data[0] != pong {
t.Fatalf("expected pong, got %d", data[0])
}
}
func TestDialUnknownFailsImmediately(t *testing.T) {
_, err := namedpipe.DialTimeout(randomPipePath(), time.Duration(0))
if !errors.Is(err, syscall.ENOENT) {
t.Fatalf("expected ENOENT got %v", err)
}
}
func TestDialListenerTimesOut(t *testing.T) {
pipePath := randomPipePath()
l, err := namedpipe.Listen(pipePath)
if err != nil {
t.Fatal(err)
}
defer l.Close()
pipe, err := namedpipe.DialTimeout(pipePath, 10*time.Millisecond)
if err == nil {
pipe.Close()
}
if err != os.ErrDeadlineExceeded {
t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err)
}
}
func TestDialContextListenerTimesOut(t *testing.T) {
pipePath := randomPipePath()
l, err := namedpipe.Listen(pipePath)
if err != nil {
t.Fatal(err)
}
defer l.Close()
d := 10 * time.Millisecond
ctx, _ := context.WithTimeout(context.Background(), d)
pipe, err := namedpipe.DialContext(ctx, pipePath)
if err == nil {
pipe.Close()
}
if err != context.DeadlineExceeded {
t.Fatalf("expected context.DeadlineExceeded, got %v", err)
}
}
func TestDialListenerGetsCancelled(t *testing.T) {
pipePath := randomPipePath()
ctx, cancel := context.WithCancel(context.Background())
l, err := namedpipe.Listen(pipePath)
if err != nil {
t.Fatal(err)
}
defer l.Close()
ch := make(chan error)
go func(ctx context.Context, ch chan error) {
_, err := namedpipe.DialContext(ctx, pipePath)
ch <- err
}(ctx, ch)
time.Sleep(time.Millisecond * 30)
cancel()
err = <-ch
if err != context.Canceled {
t.Fatalf("expected context.Canceled, got %v", err)
}
}
func TestDialAccessDeniedWithRestrictedSD(t *testing.T) {
if windows.NewLazySystemDLL("ntdll.dll").NewProc("wine_get_version").Find() == nil {
t.Skip("dacls on named pipes are broken on wine")
}
pipePath := randomPipePath()
sd, _ := windows.SecurityDescriptorFromString("D:")
l, err := (&namedpipe.ListenConfig{
SecurityDescriptor: sd,
}).Listen(pipePath)
if err != nil {
t.Fatal(err)
}
defer l.Close()
pipe, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
if err == nil {
pipe.Close()
}
if !errors.Is(err, windows.ERROR_ACCESS_DENIED) {
t.Fatalf("expected ERROR_ACCESS_DENIED, got %v", err)
}
}
func getConnection(cfg *namedpipe.ListenConfig) (client, server net.Conn, err error) {
pipePath := randomPipePath()
if cfg == nil {
cfg = &namedpipe.ListenConfig{}
}
l, err := cfg.Listen(pipePath)
if err != nil {
return
}
defer l.Close()
type response struct {
c net.Conn
err error
}
ch := make(chan response)
go func() {
c, err := l.Accept()
ch <- response{c, err}
}()
c, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
if err != nil {
return
}
r := <-ch
if err = r.err; err != nil {
c.Close()
return
}
client = c
server = r.c
return
}
func TestReadTimeout(t *testing.T) {
c, s, err := getConnection(nil)
if err != nil {
t.Fatal(err)
}
defer c.Close()
defer s.Close()
c.SetReadDeadline(time.Now().Add(10 * time.Millisecond))
buf := make([]byte, 10)
_, err = c.Read(buf)
if err != os.ErrDeadlineExceeded {
t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err)
}
}
func server(l net.Listener, ch chan int) {
c, err := l.Accept()
if err != nil {
panic(err)
}
rw := bufio.NewReadWriter(bufio.NewReader(c), bufio.NewWriter(c))
s, err := rw.ReadString('\n')
if err != nil {
panic(err)
}
_, err = rw.WriteString("got " + s)
if err != nil {
panic(err)
}
err = rw.Flush()
if err != nil {
panic(err)
}
c.Close()
ch <- 1
}
func TestFullListenDialReadWrite(t *testing.T) {
pipePath := randomPipePath()
l, err := namedpipe.Listen(pipePath)
if err != nil {
t.Fatal(err)
}
defer l.Close()
ch := make(chan int)
go server(l, ch)
c, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
if err != nil {
t.Fatal(err)
}
defer c.Close()
rw := bufio.NewReadWriter(bufio.NewReader(c), bufio.NewWriter(c))
_, err = rw.WriteString("hello world\n")
if err != nil {
t.Fatal(err)
}
err = rw.Flush()
if err != nil {
t.Fatal(err)
}
s, err := rw.ReadString('\n')
if err != nil {
t.Fatal(err)
}
ms := "got hello world\n"
if s != ms {
t.Errorf("expected '%s', got '%s'", ms, s)
}
<-ch
}
func TestCloseAbortsListen(t *testing.T) {
pipePath := randomPipePath()
l, err := namedpipe.Listen(pipePath)
if err != nil {
t.Fatal(err)
}
ch := make(chan error)
go func() {
_, err := l.Accept()
ch <- err
}()
time.Sleep(30 * time.Millisecond)
l.Close()
err = <-ch
if err != net.ErrClosed {
t.Fatalf("expected net.ErrClosed, got %v", err)
}
}
func ensureEOFOnClose(t *testing.T, r io.Reader, w io.Closer) {
b := make([]byte, 10)
w.Close()
n, err := r.Read(b)
if n > 0 {
t.Errorf("unexpected byte count %d", n)
}
if err != io.EOF {
t.Errorf("expected EOF: %v", err)
}
}
func TestCloseClientEOFServer(t *testing.T) {
c, s, err := getConnection(nil)
if err != nil {
t.Fatal(err)
}
defer c.Close()
defer s.Close()
ensureEOFOnClose(t, c, s)
}
func TestCloseServerEOFClient(t *testing.T) {
c, s, err := getConnection(nil)
if err != nil {
t.Fatal(err)
}
defer c.Close()
defer s.Close()
ensureEOFOnClose(t, s, c)
}
func TestCloseWriteEOF(t *testing.T) {
cfg := &namedpipe.ListenConfig{
MessageMode: true,
}
c, s, err := getConnection(cfg)
if err != nil {
t.Fatal(err)
}
defer c.Close()
defer s.Close()
type closeWriter interface {
CloseWrite() error
}
err = c.(closeWriter).CloseWrite()
if err != nil {
t.Fatal(err)
}
b := make([]byte, 10)
_, err = s.Read(b)
if err != io.EOF {
t.Fatal(err)
}
}
func TestAcceptAfterCloseFails(t *testing.T) {
pipePath := randomPipePath()
l, err := namedpipe.Listen(pipePath)
if err != nil {
t.Fatal(err)
}
l.Close()
_, err = l.Accept()
if err != net.ErrClosed {
t.Fatalf("expected net.ErrClosed, got %v", err)
}
}
func TestDialTimesOutByDefault(t *testing.T) {
pipePath := randomPipePath()
l, err := namedpipe.Listen(pipePath)
if err != nil {
t.Fatal(err)
}
defer l.Close()
pipe, err := namedpipe.DialTimeout(pipePath, time.Duration(0)) // Should timeout after 2 seconds.
if err == nil {
pipe.Close()
}
if err != os.ErrDeadlineExceeded {
t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err)
}
}
func TestTimeoutPendingRead(t *testing.T) {
pipePath := randomPipePath()
l, err := namedpipe.Listen(pipePath)
if err != nil {
t.Fatal(err)
}
defer l.Close()
serverDone := make(chan struct{})
go func() {
s, err := l.Accept()
if err != nil {
t.Fatal(err)
}
time.Sleep(1 * time.Second)
s.Close()
close(serverDone)
}()
client, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
if err != nil {
t.Fatal(err)
}
defer client.Close()
clientErr := make(chan error)
go func() {
buf := make([]byte, 10)
_, err = client.Read(buf)
clientErr <- err
}()
time.Sleep(100 * time.Millisecond) // make *sure* the pipe is reading before we set the deadline
client.SetReadDeadline(time.Unix(1, 0))
select {
case err = <-clientErr:
if err != os.ErrDeadlineExceeded {
t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err)
}
case <-time.After(100 * time.Millisecond):
t.Fatalf("timed out while waiting for read to cancel")
<-clientErr
}
<-serverDone
}
func TestTimeoutPendingWrite(t *testing.T) {
pipePath := randomPipePath()
l, err := namedpipe.Listen(pipePath)
if err != nil {
t.Fatal(err)
}
defer l.Close()
serverDone := make(chan struct{})
go func() {
s, err := l.Accept()
if err != nil {
t.Fatal(err)
}
time.Sleep(1 * time.Second)
s.Close()
close(serverDone)
}()
client, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
if err != nil {
t.Fatal(err)
}
defer client.Close()
clientErr := make(chan error)
go func() {
_, err = client.Write([]byte("this should timeout"))
clientErr <- err
}()
time.Sleep(100 * time.Millisecond) // make *sure* the pipe is writing before we set the deadline
client.SetWriteDeadline(time.Unix(1, 0))
select {
case err = <-clientErr:
if err != os.ErrDeadlineExceeded {
t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err)
}
case <-time.After(100 * time.Millisecond):
t.Fatalf("timed out while waiting for write to cancel")
<-clientErr
}
<-serverDone
}
type CloseWriter interface {
CloseWrite() error
}
func TestEchoWithMessaging(t *testing.T) {
pipePath := randomPipePath()
l, err := (&namedpipe.ListenConfig{
MessageMode: true, // Use message mode so that CloseWrite() is supported
InputBufferSize: 65536, // Use 64KB buffers to improve performance
OutputBufferSize: 65536,
}).Listen(pipePath)
if err != nil {
t.Fatal(err)
}
defer l.Close()
listenerDone := make(chan bool)
clientDone := make(chan bool)
go func() {
// server echo
conn, err := l.Accept()
if err != nil {
t.Fatal(err)
}
defer conn.Close()
time.Sleep(500 * time.Millisecond) // make *sure* we don't begin to read before eof signal is sent
_, err = io.Copy(conn, conn)
if err != nil {
t.Fatal(err)
}
conn.(CloseWriter).CloseWrite()
close(listenerDone)
}()
client, err := namedpipe.DialTimeout(pipePath, time.Second)
if err != nil {
t.Fatal(err)
}
defer client.Close()
go func() {
// client read back
bytes := make([]byte, 2)
n, e := client.Read(bytes)
if e != nil {
t.Fatal(e)
}
if n != 2 || bytes[0] != 0 || bytes[1] != 1 {
t.Fatalf("expected 2 bytes, got %v", n)
}
close(clientDone)
}()
payload := make([]byte, 2)
payload[0] = 0
payload[1] = 1
n, err := client.Write(payload)
if err != nil {
t.Fatal(err)
}
if n != 2 {
t.Fatalf("expected 2 bytes, got %v", n)
}
client.(CloseWriter).CloseWrite()
<-listenerDone
<-clientDone
}
func TestConnectRace(t *testing.T) {
pipePath := randomPipePath()
l, err := namedpipe.Listen(pipePath)
if err != nil {
t.Fatal(err)
}
defer l.Close()
go func() {
for {
s, err := l.Accept()
if err == net.ErrClosed {
return
}
if err != nil {
t.Fatal(err)
}
s.Close()
}
}()
for i := 0; i < 1000; i++ {
c, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
if err != nil {
t.Fatal(err)
}
c.Close()
}
}
func TestMessageReadMode(t *testing.T) {
if maj, _, _ := windows.RtlGetNtVersionNumbers(); maj <= 8 {
t.Skipf("Skipping on Windows %d", maj)
}
var wg sync.WaitGroup
defer wg.Wait()
pipePath := randomPipePath()
l, err := (&namedpipe.ListenConfig{MessageMode: true}).Listen(pipePath)
if err != nil {
t.Fatal(err)
}
defer l.Close()
msg := ([]byte)("hello world")
wg.Add(1)
go func() {
defer wg.Done()
s, err := l.Accept()
if err != nil {
t.Fatal(err)
}
_, err = s.Write(msg)
if err != nil {
t.Fatal(err)
}
s.Close()
}()
c, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
if err != nil {
t.Fatal(err)
}
defer c.Close()
mode := uint32(windows.PIPE_READMODE_MESSAGE)
err = windows.SetNamedPipeHandleState(c.(interface{ Handle() windows.Handle }).Handle(), &mode, nil, nil)
if err != nil {
t.Fatal(err)
}
ch := make([]byte, 1)
var vmsg []byte
for {
n, err := c.Read(ch)
if err == io.EOF {
break
}
if err != nil {
t.Fatal(err)
}
if n != 1 {
t.Fatalf("expected 1, got %d", n)
}
vmsg = append(vmsg, ch[0])
}
if !bytes.Equal(msg, vmsg) {
t.Fatalf("expected %s, got %s", msg, vmsg)
}
}
func TestListenConnectRace(t *testing.T) {
if testing.Short() {
t.Skip("Skipping long race test")
}
pipePath := randomPipePath()
for i := 0; i < 50 && !t.Failed(); i++ {
var wg sync.WaitGroup
wg.Add(1)
go func() {
c, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
if err == nil {
c.Close()
}
wg.Done()
}()
s, err := namedpipe.Listen(pipePath)
if err != nil {
t.Error(i, err)
} else {
s.Close()
}
wg.Wait()
}
}

@ -1,21 +1,33 @@
//go:build darwin || freebsd || openbsd // +build darwin freebsd openbsd
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/ */
package ipc package ipc
import ( import (
"errors" "errors"
"fmt"
"net" "net"
"os" "os"
"path"
"unsafe" "unsafe"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
var socketDirectory = "/var/run/wireguard"
const (
IpcErrorIO = -int64(unix.EIO)
IpcErrorProtocol = -int64(unix.EPROTO)
IpcErrorInvalid = -int64(unix.EINVAL)
IpcErrorPortInUse = -int64(unix.EADDRINUSE)
socketName = "%s.sock"
)
type UAPIListener struct { type UAPIListener struct {
listener net.Listener // unix socket listener listener net.Listener // unix socket listener
connNew chan net.Conn connNew chan net.Conn
@ -54,6 +66,7 @@ func (l *UAPIListener) Addr() net.Addr {
} }
func UAPIListen(name string, file *os.File) (net.Listener, error) { func UAPIListen(name string, file *os.File) (net.Listener, error) {
// wrap file in listener // wrap file in listener
listener, err := net.FileListener(file) listener, err := net.FileListener(file)
@ -71,7 +84,10 @@ func UAPIListen(name string, file *os.File) (net.Listener, error) {
unixListener.SetUnlinkOnClose(true) unixListener.SetUnlinkOnClose(true)
} }
socketPath := sockPath(name) socketPath := path.Join(
socketDirectory,
fmt.Sprintf(socketName, name),
)
// watch for deletion of socket // watch for deletion of socket
@ -103,7 +119,7 @@ func UAPIListen(name string, file *os.File) (net.Listener, error) {
l.connErr <- err l.connErr <- err
return return
} }
if (kerr != nil || n != 1) && kerr != unix.EINTR { if kerr != nil || n != 1 {
if kerr != nil { if kerr != nil {
l.connErr <- kerr l.connErr <- kerr
} else { } else {
@ -130,3 +146,58 @@ func UAPIListen(name string, file *os.File) (net.Listener, error) {
return uapi, nil return uapi, nil
} }
func UAPIOpen(name string) (*os.File, error) {
// check if path exist
err := os.MkdirAll(socketDirectory, 0755)
if err != nil && !os.IsExist(err) {
return nil, err
}
// open UNIX socket
socketPath := path.Join(
socketDirectory,
fmt.Sprintf(socketName, name),
)
addr, err := net.ResolveUnixAddr("unix", socketPath)
if err != nil {
return nil, err
}
oldUmask := unix.Umask(0077)
listener, err := func() (*net.UnixListener, error) {
// initial connection attempt
listener, err := net.ListenUnix("unix", addr)
if err == nil {
return listener, nil
}
// check if socket already active
_, err = net.Dial("unix", socketPath)
if err == nil {
return nil, errors.New("unix socket in use")
}
// cleanup & attempt again
err = os.Remove(socketPath)
if err != nil {
return nil, err
}
return net.ListenUnix("unix", addr)
}()
unix.Umask(oldUmask)
if err != nil {
return nil, err
}
return listener.File()
}

@ -1,16 +1,29 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/ */
package ipc package ipc
import ( import (
"errors"
"fmt"
"net" "net"
"os" "os"
"path"
"github.com/amnezia-vpn/amneziawg-go/rwcancel"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/rwcancel"
)
var socketDirectory = "/var/run/wireguard"
const (
IpcErrorIO = -int64(unix.EIO)
IpcErrorProtocol = -int64(unix.EPROTO)
IpcErrorInvalid = -int64(unix.EINVAL)
IpcErrorPortInUse = -int64(unix.EADDRINUSE)
socketName = "%s.sock"
) )
type UAPIListener struct { type UAPIListener struct {
@ -51,6 +64,7 @@ func (l *UAPIListener) Addr() net.Addr {
} }
func UAPIListen(name string, file *os.File) (net.Listener, error) { func UAPIListen(name string, file *os.File) (net.Listener, error) {
// wrap file in listener // wrap file in listener
listener, err := net.FileListener(file) listener, err := net.FileListener(file)
@ -70,7 +84,10 @@ func UAPIListen(name string, file *os.File) (net.Listener, error) {
// watch for deletion of socket // watch for deletion of socket
socketPath := sockPath(name) socketPath := path.Join(
socketDirectory,
fmt.Sprintf(socketName, name),
)
uapi.inotifyFd, err = unix.InotifyInit() uapi.inotifyFd, err = unix.InotifyInit()
if err != nil { if err != nil {
@ -96,15 +113,14 @@ func UAPIListen(name string, file *os.File) (net.Listener, error) {
} }
go func(l *UAPIListener) { go func(l *UAPIListener) {
var buf [0]byte var buff [0]byte
for { for {
defer uapi.inotifyRWCancel.Close()
// start with lstat to avoid race condition // start with lstat to avoid race condition
if _, err := os.Lstat(socketPath); os.IsNotExist(err) { if _, err := os.Lstat(socketPath); os.IsNotExist(err) {
l.connErr <- err l.connErr <- err
return return
} }
_, err := uapi.inotifyRWCancel.Read(buf[:]) _, err := uapi.inotifyRWCancel.Read(buff[:])
if err != nil { if err != nil {
l.connErr <- err l.connErr <- err
return return
@ -127,3 +143,58 @@ func UAPIListen(name string, file *os.File) (net.Listener, error) {
return uapi, nil return uapi, nil
} }
func UAPIOpen(name string) (*os.File, error) {
// check if path exist
err := os.MkdirAll(socketDirectory, 0755)
if err != nil && !os.IsExist(err) {
return nil, err
}
// open UNIX socket
socketPath := path.Join(
socketDirectory,
fmt.Sprintf(socketName, name),
)
addr, err := net.ResolveUnixAddr("unix", socketPath)
if err != nil {
return nil, err
}
oldUmask := unix.Umask(0077)
listener, err := func() (*net.UnixListener, error) {
// initial connection attempt
listener, err := net.ListenUnix("unix", addr)
if err == nil {
return listener, nil
}
// check if socket already active
_, err = net.Dial("unix", socketPath)
if err == nil {
return nil, errors.New("unix socket in use")
}
// cleanup & attempt again
err = os.Remove(socketPath)
if err != nil {
return nil, err
}
return net.ListenUnix("unix", addr)
}()
unix.Umask(oldUmask)
if err != nil {
return nil, err
}
return listener.File()
}

@ -1,66 +0,0 @@
//go:build linux || darwin || freebsd || openbsd
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package ipc
import (
"errors"
"fmt"
"net"
"os"
"golang.org/x/sys/unix"
)
const (
IpcErrorIO = -int64(unix.EIO)
IpcErrorProtocol = -int64(unix.EPROTO)
IpcErrorInvalid = -int64(unix.EINVAL)
IpcErrorPortInUse = -int64(unix.EADDRINUSE)
IpcErrorUnknown = -55 // ENOANO
)
// socketDirectory is variable because it is modified by a linker
// flag in wireguard-android.
var socketDirectory = "/var/run/amneziawg"
func sockPath(iface string) string {
return fmt.Sprintf("%s/%s.sock", socketDirectory, iface)
}
func UAPIOpen(name string) (*os.File, error) {
if err := os.MkdirAll(socketDirectory, 0o755); err != nil {
return nil, err
}
socketPath := sockPath(name)
addr, err := net.ResolveUnixAddr("unix", socketPath)
if err != nil {
return nil, err
}
oldUmask := unix.Umask(0o077)
defer unix.Umask(oldUmask)
listener, err := net.ListenUnix("unix", addr)
if err == nil {
return listener.File()
}
// Test socket, if not in use cleanup and try again.
if _, err := net.Dial("unix", socketPath); err == nil {
return nil, errors.New("unix socket in use")
}
if err := os.Remove(socketPath); err != nil {
return nil, err
}
listener, err = net.ListenUnix("unix", addr)
if err != nil {
return nil, err
}
return listener.File()
}
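
Usage aside (not part of this file): UAPIOpen creates the socket that a management tool later connects to. The sketch below follows the published WireGuard cross-platform UAPI exchange (a "get=1" request answered by key=value lines and a blank terminator); the helper name is hypothetical and the socket directory matches this file:

//go:build linux || darwin || freebsd || openbsd

package uapidemo // hypothetical example package, not part of this repository

import (
    "bufio"
    "fmt"
    "net"
)

// dumpConfig connects to the per-interface UAPI socket and prints the
// key=value lines returned for a "get=1" request.
func dumpConfig(iface string) error {
    conn, err := net.Dial("unix", fmt.Sprintf("/var/run/amneziawg/%s.sock", iface))
    if err != nil {
        return err
    }
    defer conn.Close()
    if _, err := conn.Write([]byte("get=1\n\n")); err != nil {
        return err
    }
    scanner := bufio.NewScanner(conn)
    for scanner.Scan() {
        line := scanner.Text()
        if line == "" {
            break // blank line terminates the response
        }
        fmt.Println(line)
    }
    return scanner.Err()
}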

@ -1,15 +0,0 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package ipc
// Made up sentinel error codes for {js,wasip1}/wasm.
const (
IpcErrorIO = 1
IpcErrorInvalid = 2
IpcErrorPortInUse = 3
IpcErrorUnknown = 4
IpcErrorProtocol = 5
)

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/ */
package ipc package ipc
@ -8,8 +8,7 @@ package ipc
import ( import (
"net" "net"
"github.com/amnezia-vpn/amneziawg-go/ipc/namedpipe" "golang.zx2c4.com/wireguard/ipc/winpipe"
"golang.org/x/sys/windows"
) )
// TODO: replace these with actual standard windows error numbers from the win package // TODO: replace these with actual standard windows error numbers from the win package
@ -18,7 +17,6 @@ const (
IpcErrorProtocol = -int64(71) IpcErrorProtocol = -int64(71)
IpcErrorInvalid = -int64(22) IpcErrorInvalid = -int64(22)
IpcErrorPortInUse = -int64(98) IpcErrorPortInUse = -int64(98)
IpcErrorUnknown = -int64(55)
) )
type UAPIListener struct { type UAPIListener struct {
@ -49,20 +47,14 @@ func (l *UAPIListener) Addr() net.Addr {
return l.listener.Addr() return l.listener.Addr()
} }
var UAPISecurityDescriptor *windows.SECURITY_DESCRIPTOR /* SDDL_DEVOBJ_SYS_ALL from the WDK */
var UAPISecurityDescriptor = "O:SYD:P(A;;GA;;;SY)"
func init() {
var err error
UAPISecurityDescriptor, err = windows.SecurityDescriptorFromString("O:SYD:P(A;;GA;;;SY)(A;;GA;;;BA)S:(ML;;NWNRNX;;;HI)")
if err != nil {
panic(err)
}
}
func UAPIListen(name string) (net.Listener, error) { func UAPIListen(name string) (net.Listener, error) {
listener, err := (&namedpipe.ListenConfig{ config := winpipe.PipeConfig{
SecurityDescriptor: UAPISecurityDescriptor, SecurityDescriptor: UAPISecurityDescriptor,
}).Listen(`\\.\pipe\ProtectedPrefix\Administrators\AmneziaWG\` + name) }
listener, err := winpipe.ListenPipe(`\\.\pipe\ProtectedPrefix\Administrators\WireGuard\`+name, &config)
if err != nil { if err != nil {
return nil, err return nil, err
} }
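
Usage aside (not part of this file): on Windows the same UAPI conversation runs over the named pipe created above; only the dial step differs from the unix sketch. A hedged sketch using the master-side namedpipe package (administrator rights are normally required, and the helper name is hypothetical):

//go:build windows

package uapidemo // hypothetical example package, not part of this repository

import (
    "net"

    "github.com/amnezia-vpn/amneziawg-go/ipc/namedpipe"
)

// dialUAPI opens the per-interface pipe created by UAPIListen; the same
// "get=1" exchange from the unix sketch can then be run over the returned conn.
func dialUAPI(iface string) (net.Conn, error) {
    return namedpipe.DialTimeout(`\\.\pipe\ProtectedPrefix\Administrators\AmneziaWG\`+iface, 0)
}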

ipc/winpipe/file.go (new file, 322 lines)
@ -0,0 +1,322 @@
// +build windows
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2005 Microsoft
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/
package winpipe
import (
"errors"
"io"
"runtime"
"sync"
"sync/atomic"
"syscall"
"time"
)
//sys cancelIoEx(file syscall.Handle, o *syscall.Overlapped) (err error) = CancelIoEx
//sys createIoCompletionPort(file syscall.Handle, port syscall.Handle, key uintptr, threadCount uint32) (newport syscall.Handle, err error) = CreateIoCompletionPort
//sys getQueuedCompletionStatus(port syscall.Handle, bytes *uint32, key *uintptr, o **ioOperation, timeout uint32) (err error) = GetQueuedCompletionStatus
//sys setFileCompletionNotificationModes(h syscall.Handle, flags uint8) (err error) = SetFileCompletionNotificationModes
//sys wsaGetOverlappedResult(h syscall.Handle, o *syscall.Overlapped, bytes *uint32, wait bool, flags *uint32) (err error) = ws2_32.WSAGetOverlappedResult
type atomicBool int32
func (b *atomicBool) isSet() bool { return atomic.LoadInt32((*int32)(b)) != 0 }
func (b *atomicBool) setFalse() { atomic.StoreInt32((*int32)(b), 0) }
func (b *atomicBool) setTrue() { atomic.StoreInt32((*int32)(b), 1) }
func (b *atomicBool) swap(new bool) bool {
var newInt int32
if new {
newInt = 1
}
return atomic.SwapInt32((*int32)(b), newInt) == 1
}
const (
cFILE_SKIP_COMPLETION_PORT_ON_SUCCESS = 1
cFILE_SKIP_SET_EVENT_ON_HANDLE = 2
)
var (
ErrFileClosed = errors.New("file has already been closed")
ErrTimeout = &timeoutError{}
)
type timeoutError struct{}
func (e *timeoutError) Error() string { return "i/o timeout" }
func (e *timeoutError) Timeout() bool { return true }
func (e *timeoutError) Temporary() bool { return true }
type timeoutChan chan struct{}
var ioInitOnce sync.Once
var ioCompletionPort syscall.Handle
// ioResult contains the result of an asynchronous IO operation
type ioResult struct {
bytes uint32
err error
}
// ioOperation represents an outstanding asynchronous Win32 IO
type ioOperation struct {
o syscall.Overlapped
ch chan ioResult
}
func initIo() {
h, err := createIoCompletionPort(syscall.InvalidHandle, 0, 0, 0xffffffff)
if err != nil {
panic(err)
}
ioCompletionPort = h
go ioCompletionProcessor(h)
}
// win32File implements Reader, Writer, and Closer on a Win32 handle without blocking in a syscall.
// It takes ownership of this handle and will close it if it is garbage collected.
type win32File struct {
handle syscall.Handle
wg sync.WaitGroup
wgLock sync.RWMutex
closing atomicBool
socket bool
readDeadline deadlineHandler
writeDeadline deadlineHandler
}
type deadlineHandler struct {
setLock sync.Mutex
channel timeoutChan
channelLock sync.RWMutex
timer *time.Timer
timedout atomicBool
}
// makeWin32File makes a new win32File from an existing file handle
func makeWin32File(h syscall.Handle) (*win32File, error) {
f := &win32File{handle: h}
ioInitOnce.Do(initIo)
_, err := createIoCompletionPort(h, ioCompletionPort, 0, 0xffffffff)
if err != nil {
return nil, err
}
err = setFileCompletionNotificationModes(h, cFILE_SKIP_COMPLETION_PORT_ON_SUCCESS|cFILE_SKIP_SET_EVENT_ON_HANDLE)
if err != nil {
return nil, err
}
f.readDeadline.channel = make(timeoutChan)
f.writeDeadline.channel = make(timeoutChan)
return f, nil
}
func MakeOpenFile(h syscall.Handle) (io.ReadWriteCloser, error) {
return makeWin32File(h)
}
// closeHandle closes the resources associated with a Win32 handle
func (f *win32File) closeHandle() {
f.wgLock.Lock()
// Atomically set that we are closing, releasing the resources only once.
if !f.closing.swap(true) {
f.wgLock.Unlock()
// cancel all IO and wait for it to complete
cancelIoEx(f.handle, nil)
f.wg.Wait()
// at this point, no new IO can start
syscall.Close(f.handle)
f.handle = 0
} else {
f.wgLock.Unlock()
}
}
// Close closes a win32File.
func (f *win32File) Close() error {
f.closeHandle()
return nil
}
// prepareIo prepares for a new IO operation.
// The caller must call f.wg.Done() when the IO is finished, prior to Close() returning.
func (f *win32File) prepareIo() (*ioOperation, error) {
f.wgLock.RLock()
if f.closing.isSet() {
f.wgLock.RUnlock()
return nil, ErrFileClosed
}
f.wg.Add(1)
f.wgLock.RUnlock()
c := &ioOperation{}
c.ch = make(chan ioResult)
return c, nil
}
// ioCompletionProcessor processes completed async IOs forever
func ioCompletionProcessor(h syscall.Handle) {
for {
var bytes uint32
var key uintptr
var op *ioOperation
err := getQueuedCompletionStatus(h, &bytes, &key, &op, syscall.INFINITE)
if op == nil {
panic(err)
}
op.ch <- ioResult{bytes, err}
}
}
// asyncIo processes the return value from ReadFile or WriteFile, blocking until
// the operation has actually completed.
func (f *win32File) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, err error) (int, error) {
if err != syscall.ERROR_IO_PENDING {
return int(bytes), err
}
if f.closing.isSet() {
cancelIoEx(f.handle, &c.o)
}
var timeout timeoutChan
if d != nil {
d.channelLock.Lock()
timeout = d.channel
d.channelLock.Unlock()
}
var r ioResult
select {
case r = <-c.ch:
err = r.err
if err == syscall.ERROR_OPERATION_ABORTED {
if f.closing.isSet() {
err = ErrFileClosed
}
} else if err != nil && f.socket {
// err is from Win32. Query the overlapped structure to get the winsock error.
var bytes, flags uint32
err = wsaGetOverlappedResult(f.handle, &c.o, &bytes, false, &flags)
}
case <-timeout:
cancelIoEx(f.handle, &c.o)
r = <-c.ch
err = r.err
if err == syscall.ERROR_OPERATION_ABORTED {
err = ErrTimeout
}
}
// runtime.KeepAlive is needed because c is passed via native
// code to ioCompletionProcessor; c must remain alive
// until the channel read is complete.
runtime.KeepAlive(c)
return int(r.bytes), err
}
// Read reads from a file handle.
func (f *win32File) Read(b []byte) (int, error) {
c, err := f.prepareIo()
if err != nil {
return 0, err
}
defer f.wg.Done()
if f.readDeadline.timedout.isSet() {
return 0, ErrTimeout
}
var bytes uint32
err = syscall.ReadFile(f.handle, b, &bytes, &c.o)
n, err := f.asyncIo(c, &f.readDeadline, bytes, err)
runtime.KeepAlive(b)
// Handle EOF conditions.
if err == nil && n == 0 && len(b) != 0 {
return 0, io.EOF
} else if err == syscall.ERROR_BROKEN_PIPE {
return 0, io.EOF
} else {
return n, err
}
}
// Write writes to a file handle.
func (f *win32File) Write(b []byte) (int, error) {
c, err := f.prepareIo()
if err != nil {
return 0, err
}
defer f.wg.Done()
if f.writeDeadline.timedout.isSet() {
return 0, ErrTimeout
}
var bytes uint32
err = syscall.WriteFile(f.handle, b, &bytes, &c.o)
n, err := f.asyncIo(c, &f.writeDeadline, bytes, err)
runtime.KeepAlive(b)
return n, err
}
func (f *win32File) SetReadDeadline(deadline time.Time) error {
return f.readDeadline.set(deadline)
}
func (f *win32File) SetWriteDeadline(deadline time.Time) error {
return f.writeDeadline.set(deadline)
}
func (f *win32File) Flush() error {
return syscall.FlushFileBuffers(f.handle)
}
func (f *win32File) Fd() uintptr {
return uintptr(f.handle)
}
func (d *deadlineHandler) set(deadline time.Time) error {
d.setLock.Lock()
defer d.setLock.Unlock()
if d.timer != nil {
if !d.timer.Stop() {
<-d.channel
}
d.timer = nil
}
d.timedout.setFalse()
select {
case <-d.channel:
d.channelLock.Lock()
d.channel = make(chan struct{})
d.channelLock.Unlock()
default:
}
if deadline.IsZero() {
return nil
}
timeoutIO := func() {
d.timedout.setTrue()
close(d.channel)
}
now := time.Now()
duration := deadline.Sub(now)
if deadline.After(now) {
// Deadline is in the future, set a timer to wait
d.timer = time.AfterFunc(duration, timeoutIO)
} else {
// Deadline is in the past. Cancel all pending IO now.
timeoutIO()
}
return nil
}

ipc/winpipe/mksyscall.go (new file, 9 lines)
@ -0,0 +1,9 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2005 Microsoft
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/
package winpipe
//go:generate go run $GOROOT/src/syscall/mksyscall_windows.go -output zsyscall_windows.go pipe.go sd.go file.go

ipc/winpipe/pipe.go (new file, 532 lines)
@ -0,0 +1,532 @@
// +build windows
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2005 Microsoft
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/
package winpipe
import (
"context"
"errors"
"fmt"
"io"
"net"
"os"
"runtime"
"syscall"
"time"
"unsafe"
)
//sys connectNamedPipe(pipe syscall.Handle, o *syscall.Overlapped) (err error) = ConnectNamedPipe
//sys createNamedPipe(name string, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *syscall.SecurityAttributes) (handle syscall.Handle, err error) [failretval==syscall.InvalidHandle] = CreateNamedPipeW
//sys createFile(name string, access uint32, mode uint32, sa *syscall.SecurityAttributes, createmode uint32, attrs uint32, templatefile syscall.Handle) (handle syscall.Handle, err error) [failretval==syscall.InvalidHandle] = CreateFileW
//sys getNamedPipeInfo(pipe syscall.Handle, flags *uint32, outSize *uint32, inSize *uint32, maxInstances *uint32) (err error) = GetNamedPipeInfo
//sys getNamedPipeHandleState(pipe syscall.Handle, state *uint32, curInstances *uint32, maxCollectionCount *uint32, collectDataTimeout *uint32, userName *uint16, maxUserNameSize uint32) (err error) = GetNamedPipeHandleStateW
//sys localAlloc(uFlags uint32, length uint32) (ptr uintptr) = LocalAlloc
//sys ntCreateNamedPipeFile(pipe *syscall.Handle, access uint32, oa *objectAttributes, iosb *ioStatusBlock, share uint32, disposition uint32, options uint32, typ uint32, readMode uint32, completionMode uint32, maxInstances uint32, inboundQuota uint32, outputQuota uint32, timeout *int64) (status ntstatus) = ntdll.NtCreateNamedPipeFile
//sys rtlNtStatusToDosError(status ntstatus) (winerr error) = ntdll.RtlNtStatusToDosErrorNoTeb
//sys rtlDosPathNameToNtPathName(name *uint16, ntName *unicodeString, filePart uintptr, reserved uintptr) (status ntstatus) = ntdll.RtlDosPathNameToNtPathName_U
//sys rtlDefaultNpAcl(dacl *uintptr) (status ntstatus) = ntdll.RtlDefaultNpAcl
type ioStatusBlock struct {
Status, Information uintptr
}
type objectAttributes struct {
Length uintptr
RootDirectory uintptr
ObjectName *unicodeString
Attributes uintptr
SecurityDescriptor *securityDescriptor
SecurityQoS uintptr
}
type unicodeString struct {
Length uint16
MaximumLength uint16
Buffer uintptr
}
type securityDescriptor struct {
Revision byte
Sbz1 byte
Control uint16
Owner uintptr
Group uintptr
Sacl uintptr
Dacl uintptr
}
type ntstatus int32
func (status ntstatus) Err() error {
if status >= 0 {
return nil
}
return rtlNtStatusToDosError(status)
}
const (
cERROR_PIPE_BUSY = syscall.Errno(231)
cERROR_NO_DATA = syscall.Errno(232)
cERROR_PIPE_CONNECTED = syscall.Errno(535)
cERROR_SEM_TIMEOUT = syscall.Errno(121)
cSECURITY_SQOS_PRESENT = 0x100000
cSECURITY_ANONYMOUS = 0
cPIPE_TYPE_MESSAGE = 4
cPIPE_READMODE_MESSAGE = 2
cFILE_OPEN = 1
cFILE_CREATE = 2
cFILE_PIPE_MESSAGE_TYPE = 1
cFILE_PIPE_REJECT_REMOTE_CLIENTS = 2
cSE_DACL_PRESENT = 4
)
var (
// ErrPipeListenerClosed is returned for pipe operations on listeners that have been closed.
// This error should match net.errClosing since docker takes a dependency on its text.
ErrPipeListenerClosed = errors.New("use of closed network connection")
errPipeWriteClosed = errors.New("pipe has been closed for write")
)
type win32Pipe struct {
*win32File
path string
}
type win32MessageBytePipe struct {
win32Pipe
writeClosed bool
readEOF bool
}
type pipeAddress string
func (f *win32Pipe) LocalAddr() net.Addr {
return pipeAddress(f.path)
}
func (f *win32Pipe) RemoteAddr() net.Addr {
return pipeAddress(f.path)
}
func (f *win32Pipe) SetDeadline(t time.Time) error {
f.SetReadDeadline(t)
f.SetWriteDeadline(t)
return nil
}
// CloseWrite closes the write side of a message pipe in byte mode.
func (f *win32MessageBytePipe) CloseWrite() error {
if f.writeClosed {
return errPipeWriteClosed
}
err := f.win32File.Flush()
if err != nil {
return err
}
_, err = f.win32File.Write(nil)
if err != nil {
return err
}
f.writeClosed = true
return nil
}
// Write writes bytes to a message pipe in byte mode. Zero-byte writes are ignored, since
// they are used to implement CloseWrite().
func (f *win32MessageBytePipe) Write(b []byte) (int, error) {
if f.writeClosed {
return 0, errPipeWriteClosed
}
if len(b) == 0 {
return 0, nil
}
return f.win32File.Write(b)
}
// Read reads bytes from a message pipe in byte mode. A read of a zero-byte message on a message
// mode pipe will return io.EOF, as will all subsequent reads.
func (f *win32MessageBytePipe) Read(b []byte) (int, error) {
if f.readEOF {
return 0, io.EOF
}
n, err := f.win32File.Read(b)
if err == io.EOF {
// If this was the result of a zero-byte read, then
// it is possible that the read was due to a zero-size
// message. Since we are simulating CloseWrite with a
// zero-byte message, ensure that all future Read() calls
// also return EOF.
f.readEOF = true
} else if err == syscall.ERROR_MORE_DATA {
// ERROR_MORE_DATA indicates that the pipe's read mode is message mode
// and the message still has more bytes. Treat this as a success, since
// this package presents all named pipes as byte streams.
err = nil
}
return n, err
}
func (s pipeAddress) Network() string {
return "pipe"
}
func (s pipeAddress) String() string {
return string(s)
}
// tryDialPipe attempts to dial the pipe at `path` until `ctx` cancellation or timeout.
func tryDialPipe(ctx context.Context, path *string) (syscall.Handle, error) {
for {
select {
case <-ctx.Done():
return syscall.Handle(0), ctx.Err()
default:
h, err := createFile(*path, syscall.GENERIC_READ|syscall.GENERIC_WRITE, 0, nil, syscall.OPEN_EXISTING, syscall.FILE_FLAG_OVERLAPPED|cSECURITY_SQOS_PRESENT|cSECURITY_ANONYMOUS, 0)
if err == nil {
return h, nil
}
if err != cERROR_PIPE_BUSY {
return h, &os.PathError{Err: err, Op: "open", Path: *path}
}
// Wait 10 msec and try again. This is rather simplistic,
// since we always retry every 10 milliseconds.
time.Sleep(time.Millisecond * 10)
}
}
}
// DialPipe connects to a named pipe by path, timing out if the connection
// takes longer than the specified duration. If timeout is nil, then we use
// a default timeout of 2 seconds. (We do not use WaitNamedPipe.)
func DialPipe(path string, timeout *time.Duration, expectedOwner *syscall.SID) (net.Conn, error) {
var absTimeout time.Time
if timeout != nil {
absTimeout = time.Now().Add(*timeout)
} else {
absTimeout = time.Now().Add(time.Second * 2)
}
ctx, _ := context.WithDeadline(context.Background(), absTimeout)
conn, err := DialPipeContext(ctx, path, expectedOwner)
if err == context.DeadlineExceeded {
return nil, ErrTimeout
}
return conn, err
}
// DialPipeContext attempts to connect to a named pipe by `path` until `ctx`
// cancellation or timeout.
func DialPipeContext(ctx context.Context, path string, expectedOwner *syscall.SID) (net.Conn, error) {
var err error
var h syscall.Handle
h, err = tryDialPipe(ctx, &path)
if err != nil {
return nil, err
}
if expectedOwner != nil {
var realOwner *syscall.SID
var realSd uintptr
err = getSecurityInfo(h, SE_FILE_OBJECT, OWNER_SECURITY_INFORMATION, &realOwner, nil, nil, nil, &realSd)
if err != nil {
syscall.Close(h)
return nil, err
}
defer localFree(realSd)
if !equalSid(realOwner, expectedOwner) {
syscall.Close(h)
return nil, syscall.ERROR_ACCESS_DENIED
}
}
var flags uint32
err = getNamedPipeInfo(h, &flags, nil, nil, nil)
if err != nil {
syscall.Close(h)
return nil, err
}
f, err := makeWin32File(h)
if err != nil {
syscall.Close(h)
return nil, err
}
// If the pipe is in message mode, return a message byte pipe, which
// supports CloseWrite().
if flags&cPIPE_TYPE_MESSAGE != 0 {
return &win32MessageBytePipe{
win32Pipe: win32Pipe{win32File: f, path: path},
}, nil
}
return &win32Pipe{win32File: f, path: path}, nil
}
type acceptResponse struct {
f *win32File
err error
}
type win32PipeListener struct {
firstHandle syscall.Handle
path string
config PipeConfig
acceptCh chan (chan acceptResponse)
closeCh chan int
doneCh chan int
}
func makeServerPipeHandle(path string, sd []byte, c *PipeConfig, first bool) (syscall.Handle, error) {
path16, err := syscall.UTF16FromString(path)
if err != nil {
return 0, &os.PathError{Op: "open", Path: path, Err: err}
}
var oa objectAttributes
oa.Length = unsafe.Sizeof(oa)
var ntPath unicodeString
if err := rtlDosPathNameToNtPathName(&path16[0], &ntPath, 0, 0).Err(); err != nil {
return 0, &os.PathError{Op: "open", Path: path, Err: err}
}
defer localFree(ntPath.Buffer)
oa.ObjectName = &ntPath
// The security descriptor is only needed for the first pipe.
if first {
if sd != nil {
len := uint32(len(sd))
sdb := localAlloc(0, len)
defer localFree(sdb)
copy((*[0xffff]byte)(unsafe.Pointer(sdb))[:], sd)
oa.SecurityDescriptor = (*securityDescriptor)(unsafe.Pointer(sdb))
} else {
// Construct the default named pipe security descriptor.
var dacl uintptr
if err := rtlDefaultNpAcl(&dacl).Err(); err != nil {
return 0, fmt.Errorf("getting default named pipe ACL: %s", err)
}
defer localFree(dacl)
sdb := &securityDescriptor{
Revision: 1,
Control: cSE_DACL_PRESENT,
Dacl: dacl,
}
oa.SecurityDescriptor = sdb
}
}
typ := uint32(cFILE_PIPE_REJECT_REMOTE_CLIENTS)
if c.MessageMode {
typ |= cFILE_PIPE_MESSAGE_TYPE
}
disposition := uint32(cFILE_OPEN)
access := uint32(syscall.GENERIC_READ | syscall.GENERIC_WRITE | syscall.SYNCHRONIZE)
if first {
disposition = cFILE_CREATE
// By not asking for read or write access, the named pipe file system
// will put this pipe into an initially disconnected state, blocking
// client connections until the next call with first == false.
access = syscall.SYNCHRONIZE
}
timeout := int64(-50 * 10000) // 50ms
var (
h syscall.Handle
iosb ioStatusBlock
)
err = ntCreateNamedPipeFile(&h, access, &oa, &iosb, syscall.FILE_SHARE_READ|syscall.FILE_SHARE_WRITE, disposition, 0, typ, 0, 0, 0xffffffff, uint32(c.InputBufferSize), uint32(c.OutputBufferSize), &timeout).Err()
if err != nil {
return 0, &os.PathError{Op: "open", Path: path, Err: err}
}
runtime.KeepAlive(ntPath)
return h, nil
}
func (l *win32PipeListener) makeServerPipe() (*win32File, error) {
h, err := makeServerPipeHandle(l.path, nil, &l.config, false)
if err != nil {
return nil, err
}
f, err := makeWin32File(h)
if err != nil {
syscall.Close(h)
return nil, err
}
return f, nil
}
func (l *win32PipeListener) makeConnectedServerPipe() (*win32File, error) {
p, err := l.makeServerPipe()
if err != nil {
return nil, err
}
// Wait for the client to connect.
ch := make(chan error)
go func(p *win32File) {
ch <- connectPipe(p)
}(p)
select {
case err = <-ch:
if err != nil {
p.Close()
p = nil
}
case <-l.closeCh:
// Abort the connect request by closing the handle.
p.Close()
p = nil
err = <-ch
if err == nil || err == ErrFileClosed {
err = ErrPipeListenerClosed
}
}
return p, err
}
func (l *win32PipeListener) listenerRoutine() {
closed := false
for !closed {
select {
case <-l.closeCh:
closed = true
case responseCh := <-l.acceptCh:
var (
p *win32File
err error
)
for {
p, err = l.makeConnectedServerPipe()
// If the connection was immediately closed by the client, try
// again.
if err != cERROR_NO_DATA {
break
}
}
responseCh <- acceptResponse{p, err}
closed = err == ErrPipeListenerClosed
}
}
syscall.Close(l.firstHandle)
l.firstHandle = 0
// Notify Close() and Accept() callers that the handle has been closed.
close(l.doneCh)
}
// PipeConfig contains configuration for the pipe listener.
type PipeConfig struct {
// SecurityDescriptor contains a Windows security descriptor in SDDL format.
SecurityDescriptor string
// MessageMode determines whether the pipe is in byte or message mode. In either
// case the pipe is read in byte mode by default. The only practical difference in
// this implementation is that CloseWrite() is only supported for message mode pipes;
// CloseWrite() is implemented as a zero-byte write, but zero-byte writes are only
// transferred to the reader (and returned as io.EOF in this implementation)
// when the pipe is in message mode.
MessageMode bool
// InputBufferSize specifies the size of the input buffer, in bytes.
InputBufferSize int32
// OutputBufferSize specifies the size of the output buffer, in bytes.
OutputBufferSize int32
}
// ListenPipe creates a listener on a Windows named pipe path, e.g. \\.\pipe\mypipe.
// The pipe must not already exist.
func ListenPipe(path string, c *PipeConfig) (net.Listener, error) {
var (
sd []byte
err error
)
if c == nil {
c = &PipeConfig{}
}
if c.SecurityDescriptor != "" {
sd, err = SddlToSecurityDescriptor(c.SecurityDescriptor)
if err != nil {
return nil, err
}
}
h, err := makeServerPipeHandle(path, sd, c, true)
if err != nil {
return nil, err
}
l := &win32PipeListener{
firstHandle: h,
path: path,
config: *c,
acceptCh: make(chan (chan acceptResponse)),
closeCh: make(chan int),
doneCh: make(chan int),
}
go l.listenerRoutine()
return l, nil
}
func connectPipe(p *win32File) error {
c, err := p.prepareIo()
if err != nil {
return err
}
defer p.wg.Done()
err = connectNamedPipe(p.handle, &c.o)
_, err = p.asyncIo(c, nil, 0, err)
if err != nil && err != cERROR_PIPE_CONNECTED {
return err
}
return nil
}
func (l *win32PipeListener) Accept() (net.Conn, error) {
ch := make(chan acceptResponse)
select {
case l.acceptCh <- ch:
response := <-ch
err := response.err
if err != nil {
return nil, err
}
if l.config.MessageMode {
return &win32MessageBytePipe{
win32Pipe: win32Pipe{win32File: response.f, path: l.path},
}, nil
}
return &win32Pipe{win32File: response.f, path: l.path}, nil
case <-l.doneCh:
return nil, ErrPipeListenerClosed
}
}
func (l *win32PipeListener) Close() error {
select {
case l.closeCh <- 1:
<-l.doneCh
case <-l.doneCh:
}
return nil
}
func (l *win32PipeListener) Addr() net.Addr {
return pipeAddress(l.path)
}

ipc/winpipe/sd.go (new file, 36 lines)
@ -0,0 +1,36 @@
// +build windows
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2005 Microsoft
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/
package winpipe
import (
"unsafe"
)
//sys convertStringSecurityDescriptorToSecurityDescriptor(str string, revision uint32, sd *uintptr, size *uint32) (err error) = advapi32.ConvertStringSecurityDescriptorToSecurityDescriptorW
//sys localFree(mem uintptr) = LocalFree
//sys getSecurityDescriptorLength(sd uintptr) (len uint32) = advapi32.GetSecurityDescriptorLength
//sys getSecurityInfo(handle syscall.Handle, objectType uint32, securityInformation uint32, owner **syscall.SID, group **syscall.SID, dacl *uintptr, sacl *uintptr, sd *uintptr) (ret error) = advapi32.GetSecurityInfo
//sys equalSid(sid1 *syscall.SID, sid2 *syscall.SID) (isEqual bool) = advapi32.EqualSid
const (
SE_FILE_OBJECT = 1
OWNER_SECURITY_INFORMATION = 1
)
func SddlToSecurityDescriptor(sddl string) ([]byte, error) {
var sdBuffer uintptr
err := convertStringSecurityDescriptorToSecurityDescriptor(sddl, 1, &sdBuffer, nil)
if err != nil {
return nil, err
}
defer localFree(sdBuffer)
sd := make([]byte, getSecurityDescriptorLength(sdBuffer))
copy(sd, (*[0xffff]byte)(unsafe.Pointer(sdBuffer))[:len(sd)])
return sd, nil
}
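
Aside (not part of this file): the master-side code earlier in this diff performs the same SDDL conversion through golang.org/x/sys/windows directly, without the hand-rolled syscalls above. A minimal hedged sketch, reusing the UAPI SDDL string shown earlier:

//go:build windows

package sddldemo // hypothetical example package, not part of this repository

import (
    "fmt"

    "golang.org/x/sys/windows"
)

// printDescriptor parses an SDDL string into a self-relative security
// descriptor and round-trips it back to SDDL.
func printDescriptor() error {
    sd, err := windows.SecurityDescriptorFromString("O:SYD:P(A;;GA;;;SY)")
    if err != nil {
        return err
    }
    fmt.Println(sd.String())
    return nil
}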

@ -0,0 +1,290 @@
// Code generated by 'go generate'; DO NOT EDIT.
package winpipe
import (
"syscall"
"unsafe"
"golang.org/x/sys/windows"
)
var _ unsafe.Pointer
// Do the interface allocations only once for common
// Errno values.
const (
errnoERROR_IO_PENDING = 997
)
var (
errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
)
// errnoErr returns common boxed Errno values, to prevent
// allocations at runtime.
func errnoErr(e syscall.Errno) error {
switch e {
case 0:
return nil
case errnoERROR_IO_PENDING:
return errERROR_IO_PENDING
}
// TODO: add more here, after collecting data on the common
// error values seen on Windows. (perhaps when running
// all.bat?)
return e
}
var (
modkernel32 = windows.NewLazySystemDLL("kernel32.dll")
modntdll = windows.NewLazySystemDLL("ntdll.dll")
modadvapi32 = windows.NewLazySystemDLL("advapi32.dll")
modws2_32 = windows.NewLazySystemDLL("ws2_32.dll")
procConnectNamedPipe = modkernel32.NewProc("ConnectNamedPipe")
procCreateNamedPipeW = modkernel32.NewProc("CreateNamedPipeW")
procCreateFileW = modkernel32.NewProc("CreateFileW")
procGetNamedPipeInfo = modkernel32.NewProc("GetNamedPipeInfo")
procGetNamedPipeHandleStateW = modkernel32.NewProc("GetNamedPipeHandleStateW")
procLocalAlloc = modkernel32.NewProc("LocalAlloc")
procNtCreateNamedPipeFile = modntdll.NewProc("NtCreateNamedPipeFile")
procRtlNtStatusToDosErrorNoTeb = modntdll.NewProc("RtlNtStatusToDosErrorNoTeb")
procRtlDosPathNameToNtPathName_U = modntdll.NewProc("RtlDosPathNameToNtPathName_U")
procRtlDefaultNpAcl = modntdll.NewProc("RtlDefaultNpAcl")
procConvertStringSecurityDescriptorToSecurityDescriptorW = modadvapi32.NewProc("ConvertStringSecurityDescriptorToSecurityDescriptorW")
procLocalFree = modkernel32.NewProc("LocalFree")
procGetSecurityDescriptorLength = modadvapi32.NewProc("GetSecurityDescriptorLength")
procGetSecurityInfo = modadvapi32.NewProc("GetSecurityInfo")
procEqualSid = modadvapi32.NewProc("EqualSid")
procCancelIoEx = modkernel32.NewProc("CancelIoEx")
procCreateIoCompletionPort = modkernel32.NewProc("CreateIoCompletionPort")
procGetQueuedCompletionStatus = modkernel32.NewProc("GetQueuedCompletionStatus")
procSetFileCompletionNotificationModes = modkernel32.NewProc("SetFileCompletionNotificationModes")
procWSAGetOverlappedResult = modws2_32.NewProc("WSAGetOverlappedResult")
)
func connectNamedPipe(pipe syscall.Handle, o *syscall.Overlapped) (err error) {
r1, _, e1 := syscall.Syscall(procConnectNamedPipe.Addr(), 2, uintptr(pipe), uintptr(unsafe.Pointer(o)), 0)
if r1 == 0 {
if e1 != 0 {
err = errnoErr(e1)
} else {
err = syscall.EINVAL
}
}
return
}
func createNamedPipe(name string, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *syscall.SecurityAttributes) (handle syscall.Handle, err error) {
var _p0 *uint16
_p0, err = syscall.UTF16PtrFromString(name)
if err != nil {
return
}
return _createNamedPipe(_p0, flags, pipeMode, maxInstances, outSize, inSize, defaultTimeout, sa)
}
func _createNamedPipe(name *uint16, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *syscall.SecurityAttributes) (handle syscall.Handle, err error) {
r0, _, e1 := syscall.Syscall9(procCreateNamedPipeW.Addr(), 8, uintptr(unsafe.Pointer(name)), uintptr(flags), uintptr(pipeMode), uintptr(maxInstances), uintptr(outSize), uintptr(inSize), uintptr(defaultTimeout), uintptr(unsafe.Pointer(sa)), 0)
handle = syscall.Handle(r0)
if handle == syscall.InvalidHandle {
if e1 != 0 {
err = errnoErr(e1)
} else {
err = syscall.EINVAL
}
}
return
}
func createFile(name string, access uint32, mode uint32, sa *syscall.SecurityAttributes, createmode uint32, attrs uint32, templatefile syscall.Handle) (handle syscall.Handle, err error) {
var _p0 *uint16
_p0, err = syscall.UTF16PtrFromString(name)
if err != nil {
return
}
return _createFile(_p0, access, mode, sa, createmode, attrs, templatefile)
}
func _createFile(name *uint16, access uint32, mode uint32, sa *syscall.SecurityAttributes, createmode uint32, attrs uint32, templatefile syscall.Handle) (handle syscall.Handle, err error) {
r0, _, e1 := syscall.Syscall9(procCreateFileW.Addr(), 7, uintptr(unsafe.Pointer(name)), uintptr(access), uintptr(mode), uintptr(unsafe.Pointer(sa)), uintptr(createmode), uintptr(attrs), uintptr(templatefile), 0, 0)
handle = syscall.Handle(r0)
if handle == syscall.InvalidHandle {
if e1 != 0 {
err = errnoErr(e1)
} else {
err = syscall.EINVAL
}
}
return
}
func getNamedPipeInfo(pipe syscall.Handle, flags *uint32, outSize *uint32, inSize *uint32, maxInstances *uint32) (err error) {
r1, _, e1 := syscall.Syscall6(procGetNamedPipeInfo.Addr(), 5, uintptr(pipe), uintptr(unsafe.Pointer(flags)), uintptr(unsafe.Pointer(outSize)), uintptr(unsafe.Pointer(inSize)), uintptr(unsafe.Pointer(maxInstances)), 0)
if r1 == 0 {
if e1 != 0 {
err = errnoErr(e1)
} else {
err = syscall.EINVAL
}
}
return
}
func getNamedPipeHandleState(pipe syscall.Handle, state *uint32, curInstances *uint32, maxCollectionCount *uint32, collectDataTimeout *uint32, userName *uint16, maxUserNameSize uint32) (err error) {
r1, _, e1 := syscall.Syscall9(procGetNamedPipeHandleStateW.Addr(), 7, uintptr(pipe), uintptr(unsafe.Pointer(state)), uintptr(unsafe.Pointer(curInstances)), uintptr(unsafe.Pointer(maxCollectionCount)), uintptr(unsafe.Pointer(collectDataTimeout)), uintptr(unsafe.Pointer(userName)), uintptr(maxUserNameSize), 0, 0)
if r1 == 0 {
if e1 != 0 {
err = errnoErr(e1)
} else {
err = syscall.EINVAL
}
}
return
}
func localAlloc(uFlags uint32, length uint32) (ptr uintptr) {
r0, _, _ := syscall.Syscall(procLocalAlloc.Addr(), 2, uintptr(uFlags), uintptr(length), 0)
ptr = uintptr(r0)
return
}
func ntCreateNamedPipeFile(pipe *syscall.Handle, access uint32, oa *objectAttributes, iosb *ioStatusBlock, share uint32, disposition uint32, options uint32, typ uint32, readMode uint32, completionMode uint32, maxInstances uint32, inboundQuota uint32, outputQuota uint32, timeout *int64) (status ntstatus) {
r0, _, _ := syscall.Syscall15(procNtCreateNamedPipeFile.Addr(), 14, uintptr(unsafe.Pointer(pipe)), uintptr(access), uintptr(unsafe.Pointer(oa)), uintptr(unsafe.Pointer(iosb)), uintptr(share), uintptr(disposition), uintptr(options), uintptr(typ), uintptr(readMode), uintptr(completionMode), uintptr(maxInstances), uintptr(inboundQuota), uintptr(outputQuota), uintptr(unsafe.Pointer(timeout)), 0)
status = ntstatus(r0)
return
}
func rtlNtStatusToDosError(status ntstatus) (winerr error) {
r0, _, _ := syscall.Syscall(procRtlNtStatusToDosErrorNoTeb.Addr(), 1, uintptr(status), 0, 0)
if r0 != 0 {
winerr = syscall.Errno(r0)
}
return
}
func rtlDosPathNameToNtPathName(name *uint16, ntName *unicodeString, filePart uintptr, reserved uintptr) (status ntstatus) {
r0, _, _ := syscall.Syscall6(procRtlDosPathNameToNtPathName_U.Addr(), 4, uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(ntName)), uintptr(filePart), uintptr(reserved), 0, 0)
status = ntstatus(r0)
return
}
func rtlDefaultNpAcl(dacl *uintptr) (status ntstatus) {
r0, _, _ := syscall.Syscall(procRtlDefaultNpAcl.Addr(), 1, uintptr(unsafe.Pointer(dacl)), 0, 0)
status = ntstatus(r0)
return
}
func convertStringSecurityDescriptorToSecurityDescriptor(str string, revision uint32, sd *uintptr, size *uint32) (err error) {
var _p0 *uint16
_p0, err = syscall.UTF16PtrFromString(str)
if err != nil {
return
}
return _convertStringSecurityDescriptorToSecurityDescriptor(_p0, revision, sd, size)
}
func _convertStringSecurityDescriptorToSecurityDescriptor(str *uint16, revision uint32, sd *uintptr, size *uint32) (err error) {
r1, _, e1 := syscall.Syscall6(procConvertStringSecurityDescriptorToSecurityDescriptorW.Addr(), 4, uintptr(unsafe.Pointer(str)), uintptr(revision), uintptr(unsafe.Pointer(sd)), uintptr(unsafe.Pointer(size)), 0, 0)
if r1 == 0 {
if e1 != 0 {
err = errnoErr(e1)
} else {
err = syscall.EINVAL
}
}
return
}
func localFree(mem uintptr) {
syscall.Syscall(procLocalFree.Addr(), 1, uintptr(mem), 0, 0)
return
}
func getSecurityDescriptorLength(sd uintptr) (len uint32) {
r0, _, _ := syscall.Syscall(procGetSecurityDescriptorLength.Addr(), 1, uintptr(sd), 0, 0)
len = uint32(r0)
return
}
func getSecurityInfo(handle syscall.Handle, objectType uint32, securityInformation uint32, owner **syscall.SID, group **syscall.SID, dacl *uintptr, sacl *uintptr, sd *uintptr) (ret error) {
r0, _, _ := syscall.Syscall9(procGetSecurityInfo.Addr(), 8, uintptr(handle), uintptr(objectType), uintptr(securityInformation), uintptr(unsafe.Pointer(owner)), uintptr(unsafe.Pointer(group)), uintptr(unsafe.Pointer(dacl)), uintptr(unsafe.Pointer(sacl)), uintptr(unsafe.Pointer(sd)), 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
func equalSid(sid1 *syscall.SID, sid2 *syscall.SID) (isEqual bool) {
r0, _, _ := syscall.Syscall(procEqualSid.Addr(), 2, uintptr(unsafe.Pointer(sid1)), uintptr(unsafe.Pointer(sid2)), 0)
isEqual = r0 != 0
return
}
func cancelIoEx(file syscall.Handle, o *syscall.Overlapped) (err error) {
r1, _, e1 := syscall.Syscall(procCancelIoEx.Addr(), 2, uintptr(file), uintptr(unsafe.Pointer(o)), 0)
if r1 == 0 {
if e1 != 0 {
err = errnoErr(e1)
} else {
err = syscall.EINVAL
}
}
return
}
func createIoCompletionPort(file syscall.Handle, port syscall.Handle, key uintptr, threadCount uint32) (newport syscall.Handle, err error) {
r0, _, e1 := syscall.Syscall6(procCreateIoCompletionPort.Addr(), 4, uintptr(file), uintptr(port), uintptr(key), uintptr(threadCount), 0, 0)
newport = syscall.Handle(r0)
if newport == 0 {
if e1 != 0 {
err = errnoErr(e1)
} else {
err = syscall.EINVAL
}
}
return
}
func getQueuedCompletionStatus(port syscall.Handle, bytes *uint32, key *uintptr, o **ioOperation, timeout uint32) (err error) {
r1, _, e1 := syscall.Syscall6(procGetQueuedCompletionStatus.Addr(), 5, uintptr(port), uintptr(unsafe.Pointer(bytes)), uintptr(unsafe.Pointer(key)), uintptr(unsafe.Pointer(o)), uintptr(timeout), 0)
if r1 == 0 {
if e1 != 0 {
err = errnoErr(e1)
} else {
err = syscall.EINVAL
}
}
return
}
func setFileCompletionNotificationModes(h syscall.Handle, flags uint8) (err error) {
r1, _, e1 := syscall.Syscall(procSetFileCompletionNotificationModes.Addr(), 2, uintptr(h), uintptr(flags), 0)
if r1 == 0 {
if e1 != 0 {
err = errnoErr(e1)
} else {
err = syscall.EINVAL
}
}
return
}
func wsaGetOverlappedResult(h syscall.Handle, o *syscall.Overlapped, bytes *uint32, wait bool, flags *uint32) (err error) {
var _p0 uint32
if wait {
_p0 = 1
} else {
_p0 = 0
}
r1, _, e1 := syscall.Syscall6(procWSAGetOverlappedResult.Addr(), 5, uintptr(h), uintptr(unsafe.Pointer(o)), uintptr(unsafe.Pointer(bytes)), uintptr(_p0), uintptr(unsafe.Pointer(flags)), 0)
if r1 == 0 {
if e1 != 0 {
err = errnoErr(e1)
} else {
err = syscall.EINVAL
}
}
return
}
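
The //sys annotations in pipe.go and sd.go are what produce the generated file above. In go-winio-derived trees this is typically wired up with a go:generate directive along the following lines; the exact tool invocation and file list here are an assumption, not taken from this repository:

//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go file.go pipe.go sd.go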

main.go

@ -1,8 +1,8 @@
-//go:build !windows
+// +build !windows
/* SPDX-License-Identifier: MIT
 *
- * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
 */
package main
@ -13,12 +13,11 @@ import (
    "os/signal"
    "runtime"
    "strconv"
+   "syscall"
-   "github.com/amnezia-vpn/amneziawg-go/conn"
-   "github.com/amnezia-vpn/amneziawg-go/device"
-   "github.com/amnezia-vpn/amneziawg-go/ipc"
-   "github.com/amnezia-vpn/amneziawg-go/tun"
-   "golang.org/x/sys/unix"
+   "golang.zx2c4.com/wireguard/device"
+   "golang.zx2c4.com/wireguard/ipc"
+   "golang.zx2c4.com/wireguard/tun"
)
const (
@ -33,33 +32,32 @@ const (
)
func printUsage() {
-   fmt.Printf("Usage: %s [-f/--foreground] INTERFACE-NAME\n", os.Args[0])
+   fmt.Printf("usage:\n")
+   fmt.Printf("%s [-f/--foreground] INTERFACE-NAME\n", os.Args[0])
}
func warning() {
-   switch runtime.GOOS {
-   case "linux", "freebsd", "openbsd":
-       if os.Getenv(ENV_WG_PROCESS_FOREGROUND) == "1" {
-           return
-       }
-   default:
+   if runtime.GOOS != "linux" || os.Getenv(ENV_WG_PROCESS_FOREGROUND) == "1" {
        return
    }
-   fmt.Fprintln(os.Stderr, "┌──────────────────────────────────────────────────────────────┐")
-   fmt.Fprintln(os.Stderr, "│ │")
-   fmt.Fprintln(os.Stderr, "│ Running amneziawg-go is not required because this │")
-   fmt.Fprintln(os.Stderr, "│ kernel has first class support for AmneziaWG. For │")
-   fmt.Fprintln(os.Stderr, "│ information on installing the kernel module, │")
-   fmt.Fprintln(os.Stderr, "│ please visit: │")
-   fmt.Fprintln(os.Stderr, "| https://github.com/amnezia-vpn/amneziawg-linux-kernel-module │")
-   fmt.Fprintln(os.Stderr, "│ │")
-   fmt.Fprintln(os.Stderr, "└──────────────────────────────────────────────────────────────┘")
+   fmt.Fprintln(os.Stderr, "WARNING WARNING WARNING WARNING WARNING WARNING WARNING")
+   fmt.Fprintln(os.Stderr, "W G")
+   fmt.Fprintln(os.Stderr, "W You are running this software on a Linux kernel, G")
+   fmt.Fprintln(os.Stderr, "W which is probably unnecessary and misguided. This G")
+   fmt.Fprintln(os.Stderr, "W is because the Linux kernel has built-in first G")
+   fmt.Fprintln(os.Stderr, "W class support for WireGuard, and this support is G")
+   fmt.Fprintln(os.Stderr, "W much more refined than this slower userspace G")
+   fmt.Fprintln(os.Stderr, "W implementation. For more information on G")
+   fmt.Fprintln(os.Stderr, "W installing the kernel module, please visit: G")
+   fmt.Fprintln(os.Stderr, "W https://www.wireguard.com/install G")
+   fmt.Fprintln(os.Stderr, "W G")
+   fmt.Fprintln(os.Stderr, "WARNING WARNING WARNING WARNING WARNING WARNING WARNING")
}
func main() {
    if len(os.Args) == 2 && os.Args[1] == "--version" {
-       fmt.Printf("amneziawg-go %s\n\nUserspace AmneziaWG daemon for %s-%s.\nInformation available at https://amnezia.org\n", Version, runtime.GOOS, runtime.GOARCH)
+       fmt.Printf("wireguard-go v%s\n\nUserspace WireGuard daemon for %s-%s.\nInformation available at https://www.wireguard.com.\nCopyright (C) Jason A. Donenfeld <Jason@zx2c4.com>.\n", device.WireGuardGoVersion, runtime.GOOS, runtime.GOARCH)
        return
    }
@ -99,19 +97,21 @@ func main() {
    logLevel := func() int {
        switch os.Getenv("LOG_LEVEL") {
-       case "verbose", "debug":
-           return device.LogLevelVerbose
+       case "debug":
+           return device.LogLevelDebug
+       case "info":
+           return device.LogLevelInfo
        case "error":
            return device.LogLevelError
        case "silent":
            return device.LogLevelSilent
        }
-       return device.LogLevelError
+       return device.LogLevelInfo
    }()
    // open TUN device (or use supplied fd)
-   tdev, err := func() (tun.Device, error) {
+   tun, err := func() (tun.Device, error) {
        tunFdStr := os.Getenv(ENV_WG_TUN_FD)
        if tunFdStr == "" {
            return tun.CreateTUN(interfaceName, device.DefaultMTU)
@ -124,7 +124,7 @@ func main() {
            return nil, err
        }
-       err = unix.SetNonblock(int(fd), true)
+       err = syscall.SetNonblock(int(fd), true)
        if err != nil {
            return nil, err
        }
@ -134,7 +134,7 @@ func main() {
    }()
    if err == nil {
-       realInterfaceName, err2 := tdev.Name()
+       realInterfaceName, err2 := tun.Name()
        if err2 == nil {
            interfaceName = realInterfaceName
        }
@ -145,10 +145,12 @@ func main() {
fmt.Sprintf("(%s) ", interfaceName), fmt.Sprintf("(%s) ", interfaceName),
) )
logger.Verbosef("Starting amneziawg-go version %s", Version) logger.Info.Println("Starting wireguard-go version", device.WireGuardGoVersion)
logger.Debug.Println("Debug log enabled")
if err != nil { if err != nil {
logger.Errorf("Failed to create TUN device: %v", err) logger.Error.Println("Failed to create TUN device:", err)
os.Exit(ExitSetupFailed) os.Exit(ExitSetupFailed)
} }
@ -169,8 +171,9 @@ func main() {
return os.NewFile(uintptr(fd), ""), nil return os.NewFile(uintptr(fd), ""), nil
}() }()
if err != nil { if err != nil {
logger.Errorf("UAPI listen error: %v", err) logger.Error.Println("UAPI listen error:", err)
os.Exit(ExitSetupFailed) os.Exit(ExitSetupFailed)
return return
} }
@ -196,7 +199,7 @@ func main() {
            files[0], // stdin
            files[1], // stdout
            files[2], // stderr
-           tdev.File(),
+           tun.File(),
            fileUAPI,
        },
        Dir: ".",
@ -205,7 +208,7 @@ func main() {
    path, err := os.Executable()
    if err != nil {
-       logger.Errorf("Failed to determine executable: %v", err)
+       logger.Error.Println("Failed to determine executable:", err)
        os.Exit(ExitSetupFailed)
    }
@ -215,23 +218,23 @@ func main() {
        attr,
    )
    if err != nil {
-       logger.Errorf("Failed to daemonize: %v", err)
+       logger.Error.Println("Failed to daemonize:", err)
        os.Exit(ExitSetupFailed)
    }
    process.Release()
    return
}
-   device := device.NewDevice(tdev, conn.NewDefaultBind(), logger)
-   logger.Verbosef("Device started")
+   device := device.NewDevice(tun, logger)
+   logger.Info.Println("Device started")
    errs := make(chan error)
    term := make(chan os.Signal, 1)
    uapi, err := ipc.UAPIListen(interfaceName, fileUAPI)
    if err != nil {
-       logger.Errorf("Failed to listen on uapi socket: %v", err)
+       logger.Error.Println("Failed to listen on uapi socket:", err)
        os.Exit(ExitSetupFailed)
    }
@ -246,11 +249,11 @@ func main() {
        }
    }()
-   logger.Verbosef("UAPI listener started")
+   logger.Info.Println("UAPI listener started")
    // wait for program to terminate
-   signal.Notify(term, unix.SIGTERM)
+   signal.Notify(term, syscall.SIGTERM)
    signal.Notify(term, os.Interrupt)
    select {
@ -264,5 +267,5 @@ func main() {
    uapi.Close()
    device.Close()
-   logger.Verbosef("Shutting down")
+   logger.Info.Println("Shutting down")
}

main_windows.go

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
 *
- * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
 */
package main
@ -9,14 +9,12 @@ import (
    "fmt"
    "os"
    "os/signal"
+   "syscall"
-   "golang.org/x/sys/windows"
-   "github.com/amnezia-vpn/amneziawg-go/conn"
-   "github.com/amnezia-vpn/amneziawg-go/device"
-   "github.com/amnezia-vpn/amneziawg-go/ipc"
-   "github.com/amnezia-vpn/amneziawg-go/tun"
+   "golang.zx2c4.com/wireguard/device"
+   "golang.zx2c4.com/wireguard/ipc"
+   "golang.zx2c4.com/wireguard/tun"
)
const (
@ -30,36 +28,33 @@ func main() {
    }
    interfaceName := os.Args[1]
-   fmt.Fprintln(os.Stderr, "Warning: this is a test program for Windows, mainly used for debugging this Go package. For a real AmneziaWG for Windows client, please visit: https://amnezia.org")
+   fmt.Fprintln(os.Stderr, "Warning: this is a test program for Windows, mainly used for debugging this Go package. For a real WireGuard for Windows client, the repo you want is <https://git.zx2c4.com/wireguard-windows/>, which includes this code as a module.")
    logger := device.NewLogger(
-       device.LogLevelVerbose,
+       device.LogLevelDebug,
        fmt.Sprintf("(%s) ", interfaceName),
    )
-   logger.Verbosef("Starting amneziawg-go version %s", Version)
+   logger.Info.Println("Starting wireguard-go version", device.WireGuardGoVersion)
+   logger.Debug.Println("Debug log enabled")
-   tun, err := tun.CreateTUN(interfaceName, 0)
+   tun, err := tun.CreateTUN(interfaceName)
    if err == nil {
        realInterfaceName, err2 := tun.Name()
        if err2 == nil {
            interfaceName = realInterfaceName
        }
    } else {
-       logger.Errorf("Failed to create TUN device: %v", err)
+       logger.Error.Println("Failed to create TUN device:", err)
        os.Exit(ExitSetupFailed)
    }
-   device := device.NewDevice(tun, conn.NewDefaultBind(), logger)
-   err = device.Up()
-   if err != nil {
-       logger.Errorf("Failed to bring up device: %v", err)
-       os.Exit(ExitSetupFailed)
-   }
-   logger.Verbosef("Device started")
+   device := device.NewDevice(tun, logger)
+   device.Up()
+   logger.Info.Println("Device started")
    uapi, err := ipc.UAPIListen(interfaceName)
    if err != nil {
-       logger.Errorf("Failed to listen on uapi socket: %v", err)
+       logger.Error.Println("Failed to listen on uapi socket:", err)
        os.Exit(ExitSetupFailed)
    }
@ -76,13 +71,13 @@ func main() {
            go device.IpcHandle(conn)
        }
    }()
-   logger.Verbosef("UAPI listener started")
+   logger.Info.Println("UAPI listener started")
    // wait for program to terminate
    signal.Notify(term, os.Interrupt)
    signal.Notify(term, os.Kill)
-   signal.Notify(term, windows.SIGTERM)
+   signal.Notify(term, syscall.SIGTERM)
    select {
    case <-term:
@ -95,5 +90,5 @@ func main() {
    uapi.Close()
    device.Close()
-   logger.Verbosef("Shutting down")
+   logger.Info.Println("Shutting down")
}

ratelimiter/ratelimiter.go

@ -1,12 +1,12 @@
/* SPDX-License-Identifier: MIT
 *
- * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
 */
package ratelimiter
import (
-   "net/netip"
+   "net"
    "sync"
    "time"
)
@ -20,22 +20,21 @@ const (
)
type RatelimiterEntry struct {
-   mu sync.Mutex
+   sync.Mutex
    lastTime time.Time
    tokens int64
}
type Ratelimiter struct {
-   mu sync.RWMutex
-   timeNow func() time.Time
-   stopReset chan struct{} // send to reset, close to stop
-   table map[netip.Addr]*RatelimiterEntry
+   sync.RWMutex
+   stopReset chan struct{}
+   tableIPv4 map[[net.IPv4len]byte]*RatelimiterEntry
+   tableIPv6 map[[net.IPv6len]byte]*RatelimiterEntry
}
func (rate *Ratelimiter) Close() {
-   rate.mu.Lock()
-   defer rate.mu.Unlock()
+   rate.Lock()
+   defer rate.Unlock()
    if rate.stopReset != nil {
        close(rate.stopReset)
@ -43,83 +42,111 @@ func (rate *Ratelimiter) Close() {
}
func (rate *Ratelimiter) Init() {
-   rate.mu.Lock()
-   defer rate.mu.Unlock()
-   if rate.timeNow == nil {
-       rate.timeNow = time.Now
-   }
+   rate.Lock()
+   defer rate.Unlock()
    // stop any ongoing garbage collection routine
    if rate.stopReset != nil {
        close(rate.stopReset)
    }
    rate.stopReset = make(chan struct{})
-   rate.table = make(map[netip.Addr]*RatelimiterEntry)
-   stopReset := rate.stopReset // store in case Init is called again.
-   // Start garbage collection routine.
+   rate.tableIPv4 = make(map[[net.IPv4len]byte]*RatelimiterEntry)
+   rate.tableIPv6 = make(map[[net.IPv6len]byte]*RatelimiterEntry)
+   // start garbage collection routine
    go func() {
        ticker := time.NewTicker(time.Second)
        ticker.Stop()
        for {
            select {
-           case _, ok := <-stopReset:
+           case _, ok := <-rate.stopReset:
                ticker.Stop()
-               if !ok {
+               if ok {
+                   ticker = time.NewTicker(time.Second)
+               } else {
                    return
                }
-               ticker = time.NewTicker(time.Second)
            case <-ticker.C:
-               if rate.cleanup() {
-                   ticker.Stop()
-               }
+               func() {
+                   rate.Lock()
+                   defer rate.Unlock()
+                   for key, entry := range rate.tableIPv4 {
+                       entry.Lock()
+                       if time.Since(entry.lastTime) > garbageCollectTime {
+                           delete(rate.tableIPv4, key)
+                       }
+                       entry.Unlock()
+                   }
+                   for key, entry := range rate.tableIPv6 {
+                       entry.Lock()
+                       if time.Since(entry.lastTime) > garbageCollectTime {
+                           delete(rate.tableIPv6, key)
+                       }
+                       entry.Unlock()
+                   }
+                   if len(rate.tableIPv4) == 0 && len(rate.tableIPv6) == 0 {
+                       ticker.Stop()
+                   }
+               }()
            }
        }
    }()
}
-func (rate *Ratelimiter) cleanup() (empty bool) {
-   rate.mu.Lock()
-   defer rate.mu.Unlock()
-   for key, entry := range rate.table {
-       entry.mu.Lock()
-       if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime {
-           delete(rate.table, key)
-       }
-       entry.mu.Unlock()
-   }
-   return len(rate.table) == 0
-}
-func (rate *Ratelimiter) Allow(ip netip.Addr) bool {
-   var entry *RatelimiterEntry
-   // lookup entry
-   rate.mu.RLock()
-   entry = rate.table[ip]
-   rate.mu.RUnlock()
+func (rate *Ratelimiter) Allow(ip net.IP) bool {
+   var entry *RatelimiterEntry
+   var keyIPv4 [net.IPv4len]byte
+   var keyIPv6 [net.IPv6len]byte
+   // lookup entry
+   IPv4 := ip.To4()
+   IPv6 := ip.To16()
+   rate.RLock()
+   if IPv4 != nil {
+       copy(keyIPv4[:], IPv4)
+       entry = rate.tableIPv4[keyIPv4]
+   } else {
+       copy(keyIPv6[:], IPv6)
+       entry = rate.tableIPv6[keyIPv6]
+   }
+   rate.RUnlock()
    // make new entry if not found
    if entry == nil {
        entry = new(RatelimiterEntry)
        entry.tokens = maxTokens - packetCost
-       entry.lastTime = rate.timeNow()
-       rate.mu.Lock()
-       rate.table[ip] = entry
-       if len(rate.table) == 1 {
-           rate.stopReset <- struct{}{}
-       }
-       rate.mu.Unlock()
+       entry.lastTime = time.Now()
+       rate.Lock()
+       if IPv4 != nil {
+           rate.tableIPv4[keyIPv4] = entry
+           if len(rate.tableIPv4) == 1 && len(rate.tableIPv6) == 0 {
+               rate.stopReset <- struct{}{}
+           }
+       } else {
+           rate.tableIPv6[keyIPv6] = entry
+           if len(rate.tableIPv6) == 1 && len(rate.tableIPv4) == 0 {
+               rate.stopReset <- struct{}{}
+           }
+       }
+       rate.Unlock()
        return true
    }
    // add tokens to entry
-   entry.mu.Lock()
-   now := rate.timeNow()
+   entry.Lock()
+   now := time.Now()
    entry.tokens += now.Sub(entry.lastTime).Nanoseconds()
    entry.lastTime = now
    if entry.tokens > maxTokens {
@ -127,11 +154,12 @@ func (rate *Ratelimiter) Allow(ip netip.Addr) bool {
    }
    // subtract cost of packet
    if entry.tokens > packetCost {
        entry.tokens -= packetCost
-       entry.mu.Unlock()
+       entry.Unlock()
        return true
    }
-   entry.mu.Unlock()
+   entry.Unlock()
    return false
}
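
On the master side of this diff the ratelimiter keys entries by netip.Addr and exposes Init, Allow and Close. A minimal usage sketch, assuming the module path seen in main.go above; the source address and packet count are arbitrary:

package main

import (
    "fmt"
    "net/netip"

    "github.com/amnezia-vpn/amneziawg-go/ratelimiter" // assumed master-side import path
)

func main() {
    var rl ratelimiter.Ratelimiter
    rl.Init()        // starts the once-per-second garbage-collection goroutine
    defer rl.Close() // stops it again

    src := netip.MustParseAddr("192.0.2.10")
    allowed, dropped := 0, 0
    for i := 0; i < 20; i++ {
        if rl.Allow(src) { // token-bucket decision per source address
            allowed++
        } else {
            dropped++
        }
    }
    fmt.Printf("allowed %d packets, dropped %d\n", allowed, dropped)
}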

ratelimiter/ratelimiter_test.go

@ -1,31 +1,32 @@
/* SPDX-License-Identifier: MIT
 *
- * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
 */
package ratelimiter
import (
-   "net/netip"
+   "net"
    "testing"
    "time"
)
-type result struct {
+type RatelimiterResult struct {
    allowed bool
    text string
    wait time.Duration
}
func TestRatelimiter(t *testing.T) {
-   var rate Ratelimiter
-   var expectedResults []result
-   nano := func(nano int64) time.Duration {
+   var ratelimiter Ratelimiter
+   var expectedResults []RatelimiterResult
+   Nano := func(nano int64) time.Duration {
        return time.Nanosecond * time.Duration(nano)
    }
-   add := func(res result) {
+   Add := func(res RatelimiterResult) {
        expectedResults = append(
            expectedResults,
            res,
@ -33,86 +34,69 @@ func TestRatelimiter(t *testing.T) {
    }
    for i := 0; i < packetsBurstable; i++ {
-       add(result{
+       Add(RatelimiterResult{
            allowed: true,
-           text: "initial burst",
+           text: "inital burst",
        })
    }
-   add(result{
+   Add(RatelimiterResult{
        allowed: false,
        text: "after burst",
    })
-   add(result{
+   Add(RatelimiterResult{
        allowed: true,
-       wait: nano(time.Second.Nanoseconds() / packetsPerSecond),
+       wait: Nano(time.Second.Nanoseconds() / packetsPerSecond),
        text: "filling tokens for single packet",
    })
-   add(result{
+   Add(RatelimiterResult{
        allowed: false,
        text: "not having refilled enough",
    })
-   add(result{
+   Add(RatelimiterResult{
        allowed: true,
-       wait: 2 * (nano(time.Second.Nanoseconds() / packetsPerSecond)),
+       wait: 2 * (Nano(time.Second.Nanoseconds() / packetsPerSecond)),
        text: "filling tokens for two packet burst",
    })
-   add(result{
+   Add(RatelimiterResult{
        allowed: true,
        text: "second packet in 2 packet burst",
    })
-   add(result{
+   Add(RatelimiterResult{
        allowed: false,
        text: "packet following 2 packet burst",
    })
-   ips := []netip.Addr{
-       netip.MustParseAddr("127.0.0.1"),
-       netip.MustParseAddr("192.168.1.1"),
-       netip.MustParseAddr("172.167.2.3"),
-       netip.MustParseAddr("97.231.252.215"),
-       netip.MustParseAddr("248.97.91.167"),
-       netip.MustParseAddr("188.208.233.47"),
-       netip.MustParseAddr("104.2.183.179"),
-       netip.MustParseAddr("72.129.46.120"),
-       netip.MustParseAddr("2001:0db8:0a0b:12f0:0000:0000:0000:0001"),
-       netip.MustParseAddr("f5c2:818f:c052:655a:9860:b136:6894:25f0"),
-       netip.MustParseAddr("b2d7:15ab:48a7:b07c:a541:f144:a9fe:54fc"),
-       netip.MustParseAddr("a47b:786e:1671:a22b:d6f9:4ab0:abc7:c918"),
-       netip.MustParseAddr("ea1e:d155:7f7a:98fb:2bf5:9483:80f6:5445"),
-       netip.MustParseAddr("3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4"),
+   ips := []net.IP{
+       net.ParseIP("127.0.0.1"),
+       net.ParseIP("192.168.1.1"),
+       net.ParseIP("172.167.2.3"),
+       net.ParseIP("97.231.252.215"),
+       net.ParseIP("248.97.91.167"),
+       net.ParseIP("188.208.233.47"),
+       net.ParseIP("104.2.183.179"),
+       net.ParseIP("72.129.46.120"),
+       net.ParseIP("2001:0db8:0a0b:12f0:0000:0000:0000:0001"),
+       net.ParseIP("f5c2:818f:c052:655a:9860:b136:6894:25f0"),
+       net.ParseIP("b2d7:15ab:48a7:b07c:a541:f144:a9fe:54fc"),
+       net.ParseIP("a47b:786e:1671:a22b:d6f9:4ab0:abc7:c918"),
+       net.ParseIP("ea1e:d155:7f7a:98fb:2bf5:9483:80f6:5445"),
+       net.ParseIP("3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4"),
    }
-   now := time.Now()
-   rate.timeNow = func() time.Time {
-       return now
-   }
-   defer func() {
-       // Lock to avoid data race with cleanup goroutine from Init.
-       rate.mu.Lock()
-       defer rate.mu.Unlock()
-       rate.timeNow = time.Now
-   }()
-   timeSleep := func(d time.Duration) {
-       now = now.Add(d + 1)
-       rate.cleanup()
-   }
-   rate.Init()
-   defer rate.Close()
+   ratelimiter.Init()
    for i, res := range expectedResults {
-       timeSleep(res.wait)
+       time.Sleep(res.wait)
        for _, ip := range ips {
-           allowed := rate.Allow(ip)
+           allowed := ratelimiter.Allow(ip)
            if allowed != res.allowed {
-               t.Fatalf("%d: %s: rate.Allow(%q)=%v, want %v", i, res.text, ip, allowed, res.allowed)
+               t.Fatal("Test failed for", ip.String(), ", on:", i, "(", res.text, ")", "expected:", res.allowed, "got:", allowed)
            }
        }
    }
}

replay/replay.go

@ -1,62 +1,83 @@
/* SPDX-License-Identifier: MIT
 *
- * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
 */
-// Package replay implements an efficient anti-replay algorithm as specified in RFC 6479.
package replay
-type block uint64
+/* Implementation of RFC6479
+ * https://tools.ietf.org/html/rfc6479
+ *
+ * The implementation is not safe for concurrent use!
+ */
const (
-   blockBitLog = 6 // 1<<6 == 64 bits
-   blockBits = 1 << blockBitLog // must be power of 2
-   ringBlocks = 1 << 7 // must be power of 2
-   windowSize = (ringBlocks - 1) * blockBits
-   blockMask = ringBlocks - 1
-   bitMask = blockBits - 1
+   // See: https://golang.org/src/math/big/arith.go
+   _Wordm = ^uintptr(0)
+   _WordLogSize = _Wordm>>8&1 + _Wordm>>16&1 + _Wordm>>32&1
+   _WordSize = 1 << _WordLogSize
)
-// A Filter rejects replayed messages by checking if message counter value is
-// within a sliding window of previously received messages.
-// The zero value for Filter is an empty filter ready to use.
-// Filters are unsafe for concurrent use.
-type Filter struct {
-   last uint64
-   ring [ringBlocks]block
+const (
+   CounterRedundantBitsLog = _WordLogSize + 3
+   CounterRedundantBits = _WordSize * 8
+   CounterBitsTotal = 2048
+   CounterWindowSize = uint64(CounterBitsTotal - CounterRedundantBits)
+)
+const (
+   BacktrackWords = CounterBitsTotal / _WordSize
+)
+func minUint64(a uint64, b uint64) uint64 {
+   if a > b {
+       return b
+   }
+   return a
}
-// Reset resets the filter to empty state.
-func (f *Filter) Reset() {
-   f.last = 0
-   f.ring[0] = 0
+type ReplayFilter struct {
+   counter uint64
+   backtrack [BacktrackWords]uintptr
}
-// ValidateCounter checks if the counter should be accepted.
-// Overlimit counters (>= limit) are always rejected.
-func (f *Filter) ValidateCounter(counter, limit uint64) bool {
+func (filter *ReplayFilter) Init() {
+   filter.counter = 0
+   filter.backtrack[0] = 0
+}
+func (filter *ReplayFilter) ValidateCounter(counter uint64, limit uint64) bool {
    if counter >= limit {
        return false
    }
-   indexBlock := counter >> blockBitLog
-   if counter > f.last { // move window forward
-       current := f.last >> blockBitLog
-       diff := indexBlock - current
-       if diff > ringBlocks {
-           diff = ringBlocks // cap diff to clear the whole ring
-       }
-       for i := current + 1; i <= current+diff; i++ {
-           f.ring[i&blockMask] = 0
-       }
-       f.last = counter
-   } else if f.last-counter > windowSize { // behind current window
+   indexWord := counter >> CounterRedundantBitsLog
+   if counter > filter.counter {
+       // move window forward
+       current := filter.counter >> CounterRedundantBitsLog
+       diff := minUint64(indexWord-current, BacktrackWords)
+       for i := uint64(1); i <= diff; i++ {
+           filter.backtrack[(current+i)%BacktrackWords] = 0
+       }
+       filter.counter = counter
+   } else if filter.counter-counter > CounterWindowSize {
+       // behind current window
        return false
    }
+   indexWord %= BacktrackWords
+   indexBit := counter & uint64(CounterRedundantBits-1)
    // check and set bit
-   indexBlock &= blockMask
-   indexBit := counter & bitMask
-   old := f.ring[indexBlock]
-   new := old | 1<<indexBit
-   f.ring[indexBlock] = new
-   return old != new
+   oldValue := filter.backtrack[indexWord]
+   newValue := oldValue | (1 << indexBit)
+   filter.backtrack[indexWord] = newValue
+   return oldValue != newValue
}
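
The master-side replay.Filter is self-contained and its zero value is an empty filter ready to use, which makes the sliding-window behaviour easy to demonstrate. A sketch, assuming the master-side import path; the limit constant is illustrative, in the protocol the caller passes its own reject-after counter:

package main

import (
    "fmt"

    "github.com/amnezia-vpn/amneziawg-go/replay" // assumed master-side import path
)

func main() {
    var f replay.Filter
    const limit = 1 << 60 // illustrative upper bound on accepted counters

    fmt.Println(f.ValidateCounter(1, limit)) // true: first use of counter 1
    fmt.Println(f.ValidateCounter(1, limit)) // false: replayed counter
    fmt.Println(f.ValidateCounter(5, limit)) // true: window moves forward
    fmt.Println(f.ValidateCounter(3, limit)) // true: still inside the sliding window
    fmt.Println(f.ValidateCounter(3, limit)) // false: replay inside the window

    f.Reset()                                // back to an empty filter
    fmt.Println(f.ValidateCounter(1, limit)) // true again after Reset
}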

Some files were not shown because too many files have changed in this diff.