prepare async config read/write

Signed-off-by: Mark Puha <marko10@inf.elte.hu>
This commit is contained in:
Mark Puha 2023-09-11 12:00:47 +02:00
parent 0be1878d38
commit c3bc566975
5 changed files with 225 additions and 206 deletions

View file

@ -6,13 +6,14 @@
package device
import (
"log"
"runtime"
"sync"
"sync/atomic"
"time"
"github.com/tevino/abool/v2"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/ipc"
"golang.zx2c4.com/wireguard/ratelimiter"
"golang.zx2c4.com/wireguard/rwcancel"
"golang.zx2c4.com/wireguard/tun"
@ -90,18 +91,22 @@ type Device struct {
ipcMutex sync.RWMutex
closed chan struct{}
log *Logger
aSecCfg struct {
isOn bool
junkPacketCount int
junkPacketMinSize int
junkPacketMaxSize int
initPacketJunkSize int
responsePacketJunkSize int
initPacketMagicHeader uint32
responsePacketMagicHeader uint32
underloadPacketMagicHeader uint32
transportPacketMagicHeader uint32
}
isASecOn abool.AtomicBool
aSecMux sync.RWMutex
aSecCfg aSecCfgType
}
type aSecCfgType struct {
junkPacketCount int
junkPacketMinSize int
junkPacketMaxSize int
initPacketJunkSize int
responsePacketJunkSize int
initPacketMagicHeader uint32
responsePacketMagicHeader uint32
underloadPacketMagicHeader uint32
transportPacketMagicHeader uint32
}
// deviceState represents the state of a Device.
@ -572,56 +577,174 @@ func (device *Device) BindClose() error {
return err
}
func (device *Device) isAdvancedSecurityOn() bool {
return device.aSecCfg.isOn
return device.isASecOn.IsSet()
}
func (device *Device) handlePostConfig() {
if device.isAdvancedSecurityOn() {
if device.aSecCfg.junkPacketMaxSize >= 0 {
if device.aSecCfg.junkPacketMaxSize == device.aSecCfg.junkPacketMinSize {
device.aSecCfg.junkPacketMaxSize++ // to make rand gen work
} else if device.aSecCfg.junkPacketMaxSize < device.aSecCfg.junkPacketMinSize {
log.Fatalf(
"MaxSize: %d; should be greater than MinSize: %d",
device.aSecCfg.junkPacketMaxSize,
device.aSecCfg.junkPacketMinSize,
func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
isASecOn := false
device.aSecMux.Lock()
if tempASecCfg.junkPacketCount < 0 {
err = ipcErrorf(
ipc.IpcErrorInvalid,
"JunkPacketCount should be non negative",
)
} else if tempASecCfg.junkPacketCount > 0 {
device.log.Verbosef("UAPI: Updating junk_packet_count")
device.aSecCfg.junkPacketCount = tempASecCfg.junkPacketCount
isASecOn = true
}
if tempASecCfg.junkPacketMinSize != 0 {
device.log.Verbosef("UAPI: Updating junk_packet_min_size")
device.aSecCfg.junkPacketMinSize = tempASecCfg.junkPacketMinSize
isASecOn = true
}
if tempASecCfg.junkPacketMaxSize != 0 {
if tempASecCfg.junkPacketMaxSize == tempASecCfg.junkPacketMinSize {
tempASecCfg.junkPacketMaxSize++ // to make rand gen work
}
if tempASecCfg.junkPacketMaxSize >= MaxSegmentSize{
device.aSecCfg.junkPacketMinSize = 0
device.aSecCfg.junkPacketMaxSize = 1
if err != nil {
err = ipcErrorf(
ipc.IpcErrorInvalid,
"JunkPacketMaxSize: %d; should be smaller than maxSegmentSize: %d; %w",
tempASecCfg.junkPacketMaxSize,
MaxSegmentSize,
err,
)
} else {
err = ipcErrorf(
ipc.IpcErrorInvalid,
"JunkPacketMaxSize: %d; should be smaller than maxSegmentSize: %d",
tempASecCfg.junkPacketMaxSize,
MaxSegmentSize,
)
}
}
if device.aSecCfg.initPacketMagicHeader != 0 &&
device.aSecCfg.initPacketMagicHeader != 1 {
MessageInitiationType = device.aSecCfg.initPacketMagicHeader
}
if device.aSecCfg.responsePacketMagicHeader != 0 &&
device.aSecCfg.responsePacketMagicHeader != 1 {
MessageResponseType = device.aSecCfg.responsePacketMagicHeader
}
if device.aSecCfg.underloadPacketMagicHeader != 0 &&
device.aSecCfg.underloadPacketMagicHeader != 1 {
MessageCookieReplyType = device.aSecCfg.underloadPacketMagicHeader
}
if device.aSecCfg.transportPacketMagicHeader != 0 &&
device.aSecCfg.transportPacketMagicHeader != 1 {
MessageTransportType = device.aSecCfg.transportPacketMagicHeader
}
packetSizeToMsgType = map[int]uint32{
MessageInitiationSize + device.aSecCfg.initPacketJunkSize: MessageInitiationType,
MessageResponseSize + device.aSecCfg.responsePacketJunkSize: MessageResponseType,
MessageCookieReplySize: MessageCookieReplyType,
MessageTransportSize: MessageTransportType,
}
msgTypeToJunkSize = map[uint32]int{
MessageInitiationType: device.aSecCfg.initPacketJunkSize,
MessageResponseType: device.aSecCfg.responsePacketJunkSize,
MessageCookieReplyType: 0,
MessageTransportType: 0,
} else if tempASecCfg.junkPacketMaxSize < tempASecCfg.junkPacketMinSize {
if err != nil {
err = ipcErrorf(
ipc.IpcErrorInvalid,
"maxSize: %d; should be greater than minSize: %d; %w",
tempASecCfg.junkPacketMaxSize,
tempASecCfg.junkPacketMinSize,
err,
)
} else {
err = ipcErrorf(
ipc.IpcErrorInvalid,
"maxSize: %d; should be greater than minSize: %d",
tempASecCfg.junkPacketMaxSize,
tempASecCfg.junkPacketMinSize,
)
}
} else {
device.log.Verbosef("UAPI: Updating junk_packet_max_size")
device.aSecCfg.junkPacketMaxSize = tempASecCfg.junkPacketMaxSize
isASecOn = true
}
}
if tempASecCfg.initPacketJunkSize != 0 {
if 148+tempASecCfg.initPacketJunkSize >= MaxSegmentSize {
if err != nil {
err = ipcErrorf(
ipc.IpcErrorInvalid,
`init header size(148) + junkSize:%d;
should be smaller than maxSegmentSize: %d; %w`,
tempASecCfg.initPacketJunkSize,
MaxSegmentSize,
err,
)
} else {
err = ipcErrorf(
ipc.IpcErrorInvalid,
`init header size(148) + junkSize:%d;
should be smaller than maxSegmentSize: %d`,
tempASecCfg.initPacketJunkSize,
MaxSegmentSize,
)
}
} else {
device.log.Verbosef("UAPI: Updating init_packet_junk_size")
device.aSecCfg.initPacketJunkSize = tempASecCfg.initPacketJunkSize
isASecOn = true
}
}
if tempASecCfg.responsePacketJunkSize != 0 {
if 92+tempASecCfg.responsePacketJunkSize >= MaxSegmentSize {
if err != nil {
err = ipcErrorf(
ipc.IpcErrorInvalid,
`response header size(92) + junkSize:%d;
should be smaller than maxSegmentSize: %d; %w`,
tempASecCfg.responsePacketJunkSize,
MaxSegmentSize,
err,
)
} else {
err = ipcErrorf(
ipc.IpcErrorInvalid,
`response header size(92) + junkSize:%d;
should be smaller than maxSegmentSize: %d`,
tempASecCfg.responsePacketJunkSize,
MaxSegmentSize,
)
}
} else {
device.log.Verbosef("UAPI: Updating response_packet_junk_size")
device.aSecCfg.responsePacketJunkSize = tempASecCfg.responsePacketJunkSize
isASecOn = true
}
}
if device.aSecCfg.initPacketMagicHeader > 4 {
isASecOn = true
device.log.Verbosef("UAPI: Updating init_packet_magic_header")
device.aSecCfg.initPacketMagicHeader = tempASecCfg.initPacketMagicHeader
MessageInitiationType = device.aSecCfg.initPacketMagicHeader
} else {
MessageInitiationType = 1
}
if device.aSecCfg.responsePacketMagicHeader > 4 {
isASecOn = true
device.log.Verbosef("UAPI: Updating response_packet_magic_header")
device.aSecCfg.responsePacketMagicHeader = tempASecCfg.responsePacketMagicHeader
MessageResponseType = device.aSecCfg.responsePacketMagicHeader
} else {
MessageResponseType = 2
}
if device.aSecCfg.underloadPacketMagicHeader > 4 {
isASecOn = true
device.log.Verbosef("UAPI: Updating underload_packet_magic_header")
device.aSecCfg.underloadPacketMagicHeader = tempASecCfg.underloadPacketMagicHeader
MessageCookieReplyType = device.aSecCfg.underloadPacketMagicHeader
} else {
MessageCookieReplyType = 3
}
if device.aSecCfg.transportPacketMagicHeader > 4 {
isASecOn = true
device.log.Verbosef("UAPI: Updating transport_packet_magic_header")
device.aSecCfg.transportPacketMagicHeader = tempASecCfg.transportPacketMagicHeader
MessageTransportType = device.aSecCfg.transportPacketMagicHeader
} else {
MessageTransportType = 4
}
packetSizeToMsgType = map[int]uint32{
MessageInitiationSize + device.aSecCfg.initPacketJunkSize: MessageInitiationType,
MessageResponseSize + device.aSecCfg.responsePacketJunkSize: MessageResponseType,
MessageCookieReplySize: MessageCookieReplyType,
MessageTransportSize: MessageTransportType,
}
msgTypeToJunkSize = map[uint32]int{
MessageInitiationType: device.aSecCfg.initPacketJunkSize,
MessageResponseType: device.aSecCfg.responsePacketJunkSize,
MessageCookieReplyType: 0,
MessageTransportType: 0,
}
device.isASecOn.SetTo(isASecOn)
device.aSecMux.Unlock()
return nil
}

View file

@ -132,6 +132,7 @@ func (device *Device) RoutineReceiveIncoming(
}
deathSpiral = 0
device.aSecMux.RLock()
// handle each packet in the batch
for i, size := range sizes[:count] {
if size < MinMessageSize {
@ -234,6 +235,7 @@ func (device *Device) RoutineReceiveIncoming(
default:
}
}
device.aSecMux.RUnlock()
for peer, elems := range elemsByPeer {
if peer.isRunning.Load() {
peer.queue.inbound.c <- elems
@ -292,6 +294,8 @@ func (device *Device) RoutineHandshake(id int) {
for elem := range device.queue.handshake.c {
device.aSecMux.RLock()
// handle cookie fields and ratelimiting
switch elem.msgType {
@ -455,6 +459,7 @@ func (device *Device) RoutineHandshake(id int) {
peer.SendKeepalive()
}
skip:
device.aSecMux.RUnlock()
device.PutMessageBuffer(elem.buffer)
}
}

View file

@ -11,7 +11,6 @@ import (
"errors"
"fmt"
"io"
"log"
"net"
"net/netip"
"strconv"
@ -189,6 +188,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
peer := new(ipcSetPeer)
deviceConfig := true
tempASecCfg := aSecCfgType{}
scanner := bufio.NewScanner(r)
for scanner.Scan() {
line := scanner.Text()
@ -221,7 +221,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
var err error
if deviceConfig {
err = device.handleDeviceLine(key, value)
err = device.handleDeviceLine(key, value, &tempASecCfg)
} else {
err = device.handlePeerLine(peer, key, value)
}
@ -229,7 +229,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
return err
}
}
device.handlePostConfig()
device.handlePostConfig(&tempASecCfg)
peer.handlePostConfig()
if err := scanner.Err(); err != nil {
@ -238,17 +238,13 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
return nil
}
func (device *Device) handleDeviceLine(key, value string) error {
func (device *Device) handleDeviceLine(key, value string, tempASecCfg *aSecCfgType) error {
switch key {
case "private_key":
var sk NoisePrivateKey
err := sk.FromMaybeZeroHex(value)
if err != nil {
return ipcErrorf(
ipc.IpcErrorInvalid,
"failed to set private_key: %w",
err,
)
return ipcErrorf(ipc.IpcErrorInvalid,"failed to set private_key: %w",err)
}
device.log.Verbosef("UAPI: Updating private key")
device.SetPrivateKey(sk)
@ -256,11 +252,7 @@ func (device *Device) handleDeviceLine(key, value string) error {
case "listen_port":
port, err := strconv.ParseUint(value, 10, 16)
if err != nil {
return ipcErrorf(
ipc.IpcErrorInvalid,
"failed to parse listen_port: %w",
err,
)
return ipcErrorf(ipc.IpcErrorInvalid,"failed to parse listen_port: %w",err)
}
// update port and rebind
@ -271,11 +263,7 @@ func (device *Device) handleDeviceLine(key, value string) error {
device.net.Unlock()
if err := device.BindUpdate(); err != nil {
return ipcErrorf(
ipc.IpcErrorPortInUse,
"failed to set listen_port: %w",
err,
)
return ipcErrorf(ipc.IpcErrorPortInUse,"failed to set listen_port: %w",err)
}
case "fwmark":
@ -286,180 +274,80 @@ func (device *Device) handleDeviceLine(key, value string) error {
device.log.Verbosef("UAPI: Updating fwmark")
if err := device.BindSetMark(uint32(mark)); err != nil {
return ipcErrorf(
ipc.IpcErrorPortInUse,
"failed to update fwmark: %w",
err,
)
return ipcErrorf(ipc.IpcErrorPortInUse,"failed to update fwmark: %w", err)
}
case "replace_peers":
if value != "true" {
return ipcErrorf(
ipc.IpcErrorInvalid,
"failed to set replace_peers, invalid value: %v",
value,
)
return ipcErrorf(ipc.IpcErrorInvalid,"failed to set replace_peers, invalid value: %v", value)
}
device.log.Verbosef("UAPI: Removing all peers")
device.RemoveAllPeers()
case "jc":
junkPacketCount, err := strconv.Atoi(value)
if err != nil {
return ipcErrorf(
ipc.IpcErrorInvalid,
"faield to parse junk_packet_count %w",
err,
)
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse junk_packet_count %w", err)
}
tempASecCfg.junkPacketCount = junkPacketCount
if junkPacketCount < 0 {
log.Fatalf("JunkPacketCount should be non negative")
}
device.log.Verbosef("UAPI: Updating junk_packet_count")
device.aSecCfg.isOn = true
device.aSecCfg.junkPacketCount = junkPacketCount
case "jmin":
junkPacketMinSize, err := strconv.Atoi(value)
if err != nil {
return ipcErrorf(
ipc.IpcErrorInvalid,
"faield to parse junk_packet_min_size %w",
err,
)
return ipcErrorf(ipc.IpcErrorInvalid,"faield to parse junk_packet_min_size %w", err)
}
device.log.Verbosef("UAPI: Updating junk_packet_min_size")
device.aSecCfg.isOn = true
device.aSecCfg.junkPacketMinSize = junkPacketMinSize
tempASecCfg.junkPacketMinSize = junkPacketMinSize
case "jmax":
junkPacketMaxSize, err := strconv.Atoi(value)
if err != nil {
return ipcErrorf(
ipc.IpcErrorInvalid,
"faield to parse junk_packet_max_size %w",
err,
)
return ipcErrorf(ipc.IpcErrorInvalid,"faield to parse junk_packet_max_size %w", err)
}
if junkPacketMaxSize >= MaxSegmentSize {
log.Fatalf(
"JunkPacketMaxSize: %d; should be smaller than maxSegmentSize: %d",
junkPacketMaxSize,
MaxSegmentSize,
)
}
device.log.Verbosef("UAPI: Updating junk_packet_max_size")
device.aSecCfg.isOn = true
device.aSecCfg.junkPacketMaxSize = junkPacketMaxSize
tempASecCfg.junkPacketMaxSize = junkPacketMaxSize
case "s1":
initPacketJunkSize, err := strconv.Atoi(value)
if err != nil {
return ipcErrorf(
ipc.IpcErrorInvalid,
"faield to parse init_packet_junk_size %w",
err,
)
return ipcErrorf(ipc.IpcErrorInvalid,"faield to parse init_packet_junk_size %w", err)
}
if 148+initPacketJunkSize >= MaxSegmentSize {
log.Fatalf(
`init header size(148) + junkSize:%d;
should be smaller than maxSegmentSize: %d`,
initPacketJunkSize,
MaxSegmentSize,
)
}
device.log.Verbosef("UAPI: Updating init_packet_junk_size")
device.aSecCfg.isOn = true
device.aSecCfg.initPacketJunkSize = initPacketJunkSize
tempASecCfg.initPacketJunkSize = initPacketJunkSize
case "s2":
responsePacketJunkSize, err := strconv.Atoi(value)
if err != nil {
return ipcErrorf(
ipc.IpcErrorInvalid,
"faield to parse response_packet_junk_size %w",
err,
)
return ipcErrorf(ipc.IpcErrorInvalid,"faield to parse response_packet_junk_size %w", err)
}
if 92+responsePacketJunkSize >= MaxSegmentSize {
log.Fatalf(
`response header size(92) + junkSize:%d;
should be smaller than maxSegmentSize: %d`,
responsePacketJunkSize,
MaxSegmentSize,
)
}
device.log.Verbosef("UAPI: Updating response_packet_junk_size")
device.aSecCfg.isOn = true
device.aSecCfg.responsePacketJunkSize = responsePacketJunkSize
tempASecCfg.responsePacketJunkSize = responsePacketJunkSize
case "h1":
initPacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
if err != nil {
return ipcErrorf(
ipc.IpcErrorInvalid,
"faield to parse init_packet_magic_header %w",
err,
)
return ipcErrorf(ipc.IpcErrorInvalid,"faield to parse init_packet_magic_header %w", err)
}
device.log.Verbosef("UAPI: Updating init_packet_magic_header")
device.aSecCfg.isOn = true
device.aSecCfg.initPacketMagicHeader = uint32(initPacketMagicHeader)
tempASecCfg.initPacketMagicHeader = uint32(initPacketMagicHeader)
case "h2":
responsePacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
if err != nil {
return ipcErrorf(
ipc.IpcErrorInvalid,
"faield to parse response_packet_magic_header %w",
err,
)
return ipcErrorf(ipc.IpcErrorInvalid,"faield to parse response_packet_magic_header %w", err)
}
device.log.Verbosef("UAPI: Updating response_packet_magic_header")
device.aSecCfg.isOn = true
device.aSecCfg.responsePacketMagicHeader = uint32(
responsePacketMagicHeader,
)
tempASecCfg.responsePacketMagicHeader = uint32(responsePacketMagicHeader)
case "h3":
underloadPacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
if err != nil {
return ipcErrorf(
ipc.IpcErrorInvalid,
"faield to parse underload_packet_magic_header %w",
err,
)
return ipcErrorf(ipc.IpcErrorInvalid,"faield to parse underload_packet_magic_header %w", err)
}
device.log.Verbosef("UAPI: Updating underload_packet_magic_header")
device.aSecCfg.isOn = true
device.aSecCfg.underloadPacketMagicHeader = uint32(
underloadPacketMagicHeader,
)
tempASecCfg.underloadPacketMagicHeader = uint32(underloadPacketMagicHeader)
case "h4":
transportPacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
if err != nil {
return ipcErrorf(
ipc.IpcErrorInvalid,
"faield to parse transport_packet_magic_header %w",
err,
)
return ipcErrorf(ipc.IpcErrorInvalid,"faield to parse transport_packet_magic_header %w", err)
}
device.log.Verbosef("UAPI: Updating transport_packet_magic_header")
device.aSecCfg.isOn = true
device.aSecCfg.transportPacketMagicHeader = uint32(
transportPacketMagicHeader,
)
tempASecCfg.transportPacketMagicHeader = uint32(transportPacketMagicHeader)
default:
return ipcErrorf(
ipc.IpcErrorInvalid,
"invalid UAPI device key: %v",
key,
)
return ipcErrorf(ipc.IpcErrorInvalid,"invalid UAPI device key: %v",key)
}
return nil

1
go.mod
View file

@ -3,6 +3,7 @@ module golang.zx2c4.com/wireguard
go 1.20
require (
github.com/tevino/abool/v2 v2.1.0
golang.org/x/crypto v0.6.0
golang.org/x/net v0.7.0
golang.org/x/sys v0.5.1-0.20230222185716-a3b23cc77e89

2
go.sum
View file

@ -1,5 +1,7 @@
github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4=
github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA=
github.com/tevino/abool/v2 v2.1.0 h1:7w+Vf9f/5gmKT4m4qkayb33/92M+Um45F2BkHOR+L/c=
github.com/tevino/abool/v2 v2.1.0/go.mod h1:+Lmlqk6bHDWHqN1cbxqhwEAwMPXgc8I1SDEamtseuXY=
golang.org/x/crypto v0.6.0 h1:qfktjS5LUO+fFKeJXZ+ikTRijMmljikvG68fpMMruSc=
golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58=
golang.org/x/net v0.7.0 h1:rJrUqqhjsgNp7KqAIc25s9pZnjU7TUcSY7HcVZjdn1g=