From c3bc566975b354093cc66e40a2d0cecf89368f5d Mon Sep 17 00:00:00 2001 From: Mark Puha Date: Mon, 11 Sep 2023 12:00:47 +0200 Subject: [PATCH] prepare async config read/write Signed-off-by: Mark Puha --- device/device.go | 241 ++++++++++++++++++++++++++++++++++------------ device/receive.go | 5 + device/uapi.go | 182 +++++++--------------------------- go.mod | 1 + go.sum | 2 + 5 files changed, 225 insertions(+), 206 deletions(-) diff --git a/device/device.go b/device/device.go index 7405c1c..7be97ff 100644 --- a/device/device.go +++ b/device/device.go @@ -6,13 +6,14 @@ package device import ( - "log" "runtime" "sync" "sync/atomic" "time" + "github.com/tevino/abool/v2" "golang.zx2c4.com/wireguard/conn" + "golang.zx2c4.com/wireguard/ipc" "golang.zx2c4.com/wireguard/ratelimiter" "golang.zx2c4.com/wireguard/rwcancel" "golang.zx2c4.com/wireguard/tun" @@ -90,18 +91,22 @@ type Device struct { ipcMutex sync.RWMutex closed chan struct{} log *Logger - aSecCfg struct { - isOn bool - junkPacketCount int - junkPacketMinSize int - junkPacketMaxSize int - initPacketJunkSize int - responsePacketJunkSize int - initPacketMagicHeader uint32 - responsePacketMagicHeader uint32 - underloadPacketMagicHeader uint32 - transportPacketMagicHeader uint32 - } + + isASecOn abool.AtomicBool + aSecMux sync.RWMutex + aSecCfg aSecCfgType +} + +type aSecCfgType struct { + junkPacketCount int + junkPacketMinSize int + junkPacketMaxSize int + initPacketJunkSize int + responsePacketJunkSize int + initPacketMagicHeader uint32 + responsePacketMagicHeader uint32 + underloadPacketMagicHeader uint32 + transportPacketMagicHeader uint32 } // deviceState represents the state of a Device. @@ -572,56 +577,174 @@ func (device *Device) BindClose() error { return err } func (device *Device) isAdvancedSecurityOn() bool { - return device.aSecCfg.isOn + return device.isASecOn.IsSet() } -func (device *Device) handlePostConfig() { - if device.isAdvancedSecurityOn() { - if device.aSecCfg.junkPacketMaxSize >= 0 { - if device.aSecCfg.junkPacketMaxSize == device.aSecCfg.junkPacketMinSize { - device.aSecCfg.junkPacketMaxSize++ // to make rand gen work - } else if device.aSecCfg.junkPacketMaxSize < device.aSecCfg.junkPacketMinSize { - log.Fatalf( - "MaxSize: %d; should be greater than MinSize: %d", - device.aSecCfg.junkPacketMaxSize, - device.aSecCfg.junkPacketMinSize, +func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) { + isASecOn := false + device.aSecMux.Lock() + if tempASecCfg.junkPacketCount < 0 { + err = ipcErrorf( + ipc.IpcErrorInvalid, + "JunkPacketCount should be non negative", + ) + } else if tempASecCfg.junkPacketCount > 0 { + device.log.Verbosef("UAPI: Updating junk_packet_count") + device.aSecCfg.junkPacketCount = tempASecCfg.junkPacketCount + isASecOn = true + } + if tempASecCfg.junkPacketMinSize != 0 { + device.log.Verbosef("UAPI: Updating junk_packet_min_size") + device.aSecCfg.junkPacketMinSize = tempASecCfg.junkPacketMinSize + isASecOn = true + } + if tempASecCfg.junkPacketMaxSize != 0 { + if tempASecCfg.junkPacketMaxSize == tempASecCfg.junkPacketMinSize { + tempASecCfg.junkPacketMaxSize++ // to make rand gen work + } + if tempASecCfg.junkPacketMaxSize >= MaxSegmentSize{ + device.aSecCfg.junkPacketMinSize = 0 + device.aSecCfg.junkPacketMaxSize = 1 + if err != nil { + err = ipcErrorf( + ipc.IpcErrorInvalid, + "JunkPacketMaxSize: %d; should be smaller than maxSegmentSize: %d; %w", + tempASecCfg.junkPacketMaxSize, + MaxSegmentSize, + err, + ) + } else { + err = ipcErrorf( + ipc.IpcErrorInvalid, + "JunkPacketMaxSize: %d; should be smaller than maxSegmentSize: %d", + tempASecCfg.junkPacketMaxSize, + MaxSegmentSize, ) } - } - - if device.aSecCfg.initPacketMagicHeader != 0 && - device.aSecCfg.initPacketMagicHeader != 1 { - - MessageInitiationType = device.aSecCfg.initPacketMagicHeader - } - if device.aSecCfg.responsePacketMagicHeader != 0 && - device.aSecCfg.responsePacketMagicHeader != 1 { - - MessageResponseType = device.aSecCfg.responsePacketMagicHeader - } - if device.aSecCfg.underloadPacketMagicHeader != 0 && - device.aSecCfg.underloadPacketMagicHeader != 1 { - - MessageCookieReplyType = device.aSecCfg.underloadPacketMagicHeader - } - if device.aSecCfg.transportPacketMagicHeader != 0 && - device.aSecCfg.transportPacketMagicHeader != 1 { - - MessageTransportType = device.aSecCfg.transportPacketMagicHeader - } - - packetSizeToMsgType = map[int]uint32{ - MessageInitiationSize + device.aSecCfg.initPacketJunkSize: MessageInitiationType, - MessageResponseSize + device.aSecCfg.responsePacketJunkSize: MessageResponseType, - MessageCookieReplySize: MessageCookieReplyType, - MessageTransportSize: MessageTransportType, - } - - msgTypeToJunkSize = map[uint32]int{ - MessageInitiationType: device.aSecCfg.initPacketJunkSize, - MessageResponseType: device.aSecCfg.responsePacketJunkSize, - MessageCookieReplyType: 0, - MessageTransportType: 0, + } else if tempASecCfg.junkPacketMaxSize < tempASecCfg.junkPacketMinSize { + if err != nil { + err = ipcErrorf( + ipc.IpcErrorInvalid, + "maxSize: %d; should be greater than minSize: %d; %w", + tempASecCfg.junkPacketMaxSize, + tempASecCfg.junkPacketMinSize, + err, + ) + } else { + err = ipcErrorf( + ipc.IpcErrorInvalid, + "maxSize: %d; should be greater than minSize: %d", + tempASecCfg.junkPacketMaxSize, + tempASecCfg.junkPacketMinSize, + ) + } + } else { + device.log.Verbosef("UAPI: Updating junk_packet_max_size") + device.aSecCfg.junkPacketMaxSize = tempASecCfg.junkPacketMaxSize + isASecOn = true } } + if tempASecCfg.initPacketJunkSize != 0 { + if 148+tempASecCfg.initPacketJunkSize >= MaxSegmentSize { + if err != nil { + err = ipcErrorf( + ipc.IpcErrorInvalid, + `init header size(148) + junkSize:%d; + should be smaller than maxSegmentSize: %d; %w`, + tempASecCfg.initPacketJunkSize, + MaxSegmentSize, + err, + ) + } else { + err = ipcErrorf( + ipc.IpcErrorInvalid, + `init header size(148) + junkSize:%d; + should be smaller than maxSegmentSize: %d`, + tempASecCfg.initPacketJunkSize, + MaxSegmentSize, + ) + } + } else { + device.log.Verbosef("UAPI: Updating init_packet_junk_size") + device.aSecCfg.initPacketJunkSize = tempASecCfg.initPacketJunkSize + isASecOn = true + } + } + if tempASecCfg.responsePacketJunkSize != 0 { + if 92+tempASecCfg.responsePacketJunkSize >= MaxSegmentSize { + if err != nil { + err = ipcErrorf( + ipc.IpcErrorInvalid, + `response header size(92) + junkSize:%d; + should be smaller than maxSegmentSize: %d; %w`, + tempASecCfg.responsePacketJunkSize, + MaxSegmentSize, + err, + ) + } else { + err = ipcErrorf( + ipc.IpcErrorInvalid, + `response header size(92) + junkSize:%d; + should be smaller than maxSegmentSize: %d`, + tempASecCfg.responsePacketJunkSize, + MaxSegmentSize, + ) + } + } else { + device.log.Verbosef("UAPI: Updating response_packet_junk_size") + device.aSecCfg.responsePacketJunkSize = tempASecCfg.responsePacketJunkSize + isASecOn = true + } + } + + if device.aSecCfg.initPacketMagicHeader > 4 { + isASecOn = true + device.log.Verbosef("UAPI: Updating init_packet_magic_header") + device.aSecCfg.initPacketMagicHeader = tempASecCfg.initPacketMagicHeader + MessageInitiationType = device.aSecCfg.initPacketMagicHeader + } else { + MessageInitiationType = 1 + } + if device.aSecCfg.responsePacketMagicHeader > 4 { + isASecOn = true + device.log.Verbosef("UAPI: Updating response_packet_magic_header") + device.aSecCfg.responsePacketMagicHeader = tempASecCfg.responsePacketMagicHeader + MessageResponseType = device.aSecCfg.responsePacketMagicHeader + } else { + MessageResponseType = 2 + } + if device.aSecCfg.underloadPacketMagicHeader > 4 { + isASecOn = true + device.log.Verbosef("UAPI: Updating underload_packet_magic_header") + device.aSecCfg.underloadPacketMagicHeader = tempASecCfg.underloadPacketMagicHeader + MessageCookieReplyType = device.aSecCfg.underloadPacketMagicHeader + } else { + MessageCookieReplyType = 3 + } + if device.aSecCfg.transportPacketMagicHeader > 4 { + isASecOn = true + device.log.Verbosef("UAPI: Updating transport_packet_magic_header") + device.aSecCfg.transportPacketMagicHeader = tempASecCfg.transportPacketMagicHeader + MessageTransportType = device.aSecCfg.transportPacketMagicHeader + } else { + MessageTransportType = 4 + } + + packetSizeToMsgType = map[int]uint32{ + MessageInitiationSize + device.aSecCfg.initPacketJunkSize: MessageInitiationType, + MessageResponseSize + device.aSecCfg.responsePacketJunkSize: MessageResponseType, + MessageCookieReplySize: MessageCookieReplyType, + MessageTransportSize: MessageTransportType, + } + + msgTypeToJunkSize = map[uint32]int{ + MessageInitiationType: device.aSecCfg.initPacketJunkSize, + MessageResponseType: device.aSecCfg.responsePacketJunkSize, + MessageCookieReplyType: 0, + MessageTransportType: 0, + } + device.isASecOn.SetTo(isASecOn) + device.aSecMux.Unlock() + + return nil } diff --git a/device/receive.go b/device/receive.go index 4714175..b4d7e31 100644 --- a/device/receive.go +++ b/device/receive.go @@ -132,6 +132,7 @@ func (device *Device) RoutineReceiveIncoming( } deathSpiral = 0 + device.aSecMux.RLock() // handle each packet in the batch for i, size := range sizes[:count] { if size < MinMessageSize { @@ -234,6 +235,7 @@ func (device *Device) RoutineReceiveIncoming( default: } } + device.aSecMux.RUnlock() for peer, elems := range elemsByPeer { if peer.isRunning.Load() { peer.queue.inbound.c <- elems @@ -292,6 +294,8 @@ func (device *Device) RoutineHandshake(id int) { for elem := range device.queue.handshake.c { + device.aSecMux.RLock() + // handle cookie fields and ratelimiting switch elem.msgType { @@ -455,6 +459,7 @@ func (device *Device) RoutineHandshake(id int) { peer.SendKeepalive() } skip: + device.aSecMux.RUnlock() device.PutMessageBuffer(elem.buffer) } } diff --git a/device/uapi.go b/device/uapi.go index f6a0e5b..4f224fa 100644 --- a/device/uapi.go +++ b/device/uapi.go @@ -11,7 +11,6 @@ import ( "errors" "fmt" "io" - "log" "net" "net/netip" "strconv" @@ -189,6 +188,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) { peer := new(ipcSetPeer) deviceConfig := true + tempASecCfg := aSecCfgType{} scanner := bufio.NewScanner(r) for scanner.Scan() { line := scanner.Text() @@ -221,7 +221,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) { var err error if deviceConfig { - err = device.handleDeviceLine(key, value) + err = device.handleDeviceLine(key, value, &tempASecCfg) } else { err = device.handlePeerLine(peer, key, value) } @@ -229,7 +229,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) { return err } } - device.handlePostConfig() + device.handlePostConfig(&tempASecCfg) peer.handlePostConfig() if err := scanner.Err(); err != nil { @@ -238,17 +238,13 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) { return nil } -func (device *Device) handleDeviceLine(key, value string) error { +func (device *Device) handleDeviceLine(key, value string, tempASecCfg *aSecCfgType) error { switch key { case "private_key": var sk NoisePrivateKey err := sk.FromMaybeZeroHex(value) if err != nil { - return ipcErrorf( - ipc.IpcErrorInvalid, - "failed to set private_key: %w", - err, - ) + return ipcErrorf(ipc.IpcErrorInvalid,"failed to set private_key: %w",err) } device.log.Verbosef("UAPI: Updating private key") device.SetPrivateKey(sk) @@ -256,11 +252,7 @@ func (device *Device) handleDeviceLine(key, value string) error { case "listen_port": port, err := strconv.ParseUint(value, 10, 16) if err != nil { - return ipcErrorf( - ipc.IpcErrorInvalid, - "failed to parse listen_port: %w", - err, - ) + return ipcErrorf(ipc.IpcErrorInvalid,"failed to parse listen_port: %w",err) } // update port and rebind @@ -271,11 +263,7 @@ func (device *Device) handleDeviceLine(key, value string) error { device.net.Unlock() if err := device.BindUpdate(); err != nil { - return ipcErrorf( - ipc.IpcErrorPortInUse, - "failed to set listen_port: %w", - err, - ) + return ipcErrorf(ipc.IpcErrorPortInUse,"failed to set listen_port: %w",err) } case "fwmark": @@ -286,180 +274,80 @@ func (device *Device) handleDeviceLine(key, value string) error { device.log.Verbosef("UAPI: Updating fwmark") if err := device.BindSetMark(uint32(mark)); err != nil { - return ipcErrorf( - ipc.IpcErrorPortInUse, - "failed to update fwmark: %w", - err, - ) + return ipcErrorf(ipc.IpcErrorPortInUse,"failed to update fwmark: %w", err) } case "replace_peers": if value != "true" { - return ipcErrorf( - ipc.IpcErrorInvalid, - "failed to set replace_peers, invalid value: %v", - value, - ) + return ipcErrorf(ipc.IpcErrorInvalid,"failed to set replace_peers, invalid value: %v", value) } device.log.Verbosef("UAPI: Removing all peers") device.RemoveAllPeers() - + case "jc": junkPacketCount, err := strconv.Atoi(value) if err != nil { - return ipcErrorf( - ipc.IpcErrorInvalid, - "faield to parse junk_packet_count %w", - err, - ) + return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse junk_packet_count %w", err) } + tempASecCfg.junkPacketCount = junkPacketCount - if junkPacketCount < 0 { - log.Fatalf("JunkPacketCount should be non negative") - } - device.log.Verbosef("UAPI: Updating junk_packet_count") - device.aSecCfg.isOn = true - device.aSecCfg.junkPacketCount = junkPacketCount - case "jmin": junkPacketMinSize, err := strconv.Atoi(value) if err != nil { - return ipcErrorf( - ipc.IpcErrorInvalid, - "faield to parse junk_packet_min_size %w", - err, - ) + return ipcErrorf(ipc.IpcErrorInvalid,"faield to parse junk_packet_min_size %w", err) } - device.log.Verbosef("UAPI: Updating junk_packet_min_size") - device.aSecCfg.isOn = true - device.aSecCfg.junkPacketMinSize = junkPacketMinSize - + tempASecCfg.junkPacketMinSize = junkPacketMinSize + case "jmax": junkPacketMaxSize, err := strconv.Atoi(value) if err != nil { - return ipcErrorf( - ipc.IpcErrorInvalid, - "faield to parse junk_packet_max_size %w", - err, - ) + return ipcErrorf(ipc.IpcErrorInvalid,"faield to parse junk_packet_max_size %w", err) } - if junkPacketMaxSize >= MaxSegmentSize { - log.Fatalf( - "JunkPacketMaxSize: %d; should be smaller than maxSegmentSize: %d", - junkPacketMaxSize, - MaxSegmentSize, - ) - } - device.log.Verbosef("UAPI: Updating junk_packet_max_size") - device.aSecCfg.isOn = true - device.aSecCfg.junkPacketMaxSize = junkPacketMaxSize - + tempASecCfg.junkPacketMaxSize = junkPacketMaxSize + case "s1": initPacketJunkSize, err := strconv.Atoi(value) if err != nil { - return ipcErrorf( - ipc.IpcErrorInvalid, - "faield to parse init_packet_junk_size %w", - err, - ) + return ipcErrorf(ipc.IpcErrorInvalid,"faield to parse init_packet_junk_size %w", err) } - if 148+initPacketJunkSize >= MaxSegmentSize { - log.Fatalf( - `init header size(148) + junkSize:%d; - should be smaller than maxSegmentSize: %d`, - initPacketJunkSize, - MaxSegmentSize, - ) - } - device.log.Verbosef("UAPI: Updating init_packet_junk_size") - device.aSecCfg.isOn = true - device.aSecCfg.initPacketJunkSize = initPacketJunkSize - + tempASecCfg.initPacketJunkSize = initPacketJunkSize + case "s2": responsePacketJunkSize, err := strconv.Atoi(value) if err != nil { - return ipcErrorf( - ipc.IpcErrorInvalid, - "faield to parse response_packet_junk_size %w", - err, - ) + return ipcErrorf(ipc.IpcErrorInvalid,"faield to parse response_packet_junk_size %w", err) } - if 92+responsePacketJunkSize >= MaxSegmentSize { - log.Fatalf( - `response header size(92) + junkSize:%d; - should be smaller than maxSegmentSize: %d`, - responsePacketJunkSize, - MaxSegmentSize, - ) - } - device.log.Verbosef("UAPI: Updating response_packet_junk_size") - device.aSecCfg.isOn = true - device.aSecCfg.responsePacketJunkSize = responsePacketJunkSize + tempASecCfg.responsePacketJunkSize = responsePacketJunkSize - case "h1": initPacketMagicHeader, err := strconv.ParseUint(value, 10, 32) if err != nil { - return ipcErrorf( - ipc.IpcErrorInvalid, - "faield to parse init_packet_magic_header %w", - err, - ) + return ipcErrorf(ipc.IpcErrorInvalid,"faield to parse init_packet_magic_header %w", err) } - device.log.Verbosef("UAPI: Updating init_packet_magic_header") - device.aSecCfg.isOn = true - device.aSecCfg.initPacketMagicHeader = uint32(initPacketMagicHeader) - + tempASecCfg.initPacketMagicHeader = uint32(initPacketMagicHeader) + case "h2": responsePacketMagicHeader, err := strconv.ParseUint(value, 10, 32) if err != nil { - return ipcErrorf( - ipc.IpcErrorInvalid, - "faield to parse response_packet_magic_header %w", - err, - ) + return ipcErrorf(ipc.IpcErrorInvalid,"faield to parse response_packet_magic_header %w", err) } - device.log.Verbosef("UAPI: Updating response_packet_magic_header") - device.aSecCfg.isOn = true - device.aSecCfg.responsePacketMagicHeader = uint32( - responsePacketMagicHeader, - ) - + tempASecCfg.responsePacketMagicHeader = uint32(responsePacketMagicHeader) + case "h3": underloadPacketMagicHeader, err := strconv.ParseUint(value, 10, 32) if err != nil { - return ipcErrorf( - ipc.IpcErrorInvalid, - "faield to parse underload_packet_magic_header %w", - err, - ) + return ipcErrorf(ipc.IpcErrorInvalid,"faield to parse underload_packet_magic_header %w", err) } - device.log.Verbosef("UAPI: Updating underload_packet_magic_header") - device.aSecCfg.isOn = true - device.aSecCfg.underloadPacketMagicHeader = uint32( - underloadPacketMagicHeader, - ) - + tempASecCfg.underloadPacketMagicHeader = uint32(underloadPacketMagicHeader) + case "h4": transportPacketMagicHeader, err := strconv.ParseUint(value, 10, 32) if err != nil { - return ipcErrorf( - ipc.IpcErrorInvalid, - "faield to parse transport_packet_magic_header %w", - err, - ) + return ipcErrorf(ipc.IpcErrorInvalid,"faield to parse transport_packet_magic_header %w", err) } - device.log.Verbosef("UAPI: Updating transport_packet_magic_header") - device.aSecCfg.isOn = true - device.aSecCfg.transportPacketMagicHeader = uint32( - transportPacketMagicHeader, - ) + tempASecCfg.transportPacketMagicHeader = uint32(transportPacketMagicHeader) default: - return ipcErrorf( - ipc.IpcErrorInvalid, - "invalid UAPI device key: %v", - key, - ) + return ipcErrorf(ipc.IpcErrorInvalid,"invalid UAPI device key: %v",key) } return nil diff --git a/go.mod b/go.mod index c04e1bb..9616041 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module golang.zx2c4.com/wireguard go 1.20 require ( + github.com/tevino/abool/v2 v2.1.0 golang.org/x/crypto v0.6.0 golang.org/x/net v0.7.0 golang.org/x/sys v0.5.1-0.20230222185716-a3b23cc77e89 diff --git a/go.sum b/go.sum index cfeaee6..3707808 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,7 @@ github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4= github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA= +github.com/tevino/abool/v2 v2.1.0 h1:7w+Vf9f/5gmKT4m4qkayb33/92M+Um45F2BkHOR+L/c= +github.com/tevino/abool/v2 v2.1.0/go.mod h1:+Lmlqk6bHDWHqN1cbxqhwEAwMPXgc8I1SDEamtseuXY= golang.org/x/crypto v0.6.0 h1:qfktjS5LUO+fFKeJXZ+ikTRijMmljikvG68fpMMruSc= golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58= golang.org/x/net v0.7.0 h1:rJrUqqhjsgNp7KqAIc25s9pZnjU7TUcSY7HcVZjdn1g=