From c5312e274005922c63cbefc54aef54540f4ce7e2 Mon Sep 17 00:00:00 2001 From: Mark Puha Date: Mon, 7 Jul 2025 20:29:46 +0200 Subject: [PATCH] feat: continue range h1-h4 --- device/awg/awg.go | 87 +++++++++-------- device/device.go | 221 ++++++++++++++++++++++++------------------ device/device_test.go | 8 +- device/receive.go | 35 +------ device/uapi.go | 7 +- go.mod | 2 +- go.sum | 14 ++- 7 files changed, 196 insertions(+), 178 deletions(-) diff --git a/device/awg/awg.go b/device/awg/awg.go index fd5a96d..838e7ca 100644 --- a/device/awg/awg.go +++ b/device/awg/awg.go @@ -12,22 +12,23 @@ import ( ) type aSecCfgType struct { - IsSet bool - JunkPacketCount int - JunkPacketMinSize int - JunkPacketMaxSize int - InitHeaderJunkSize int - ResponseHeaderJunkSize int - CookieReplyHeaderJunkSize int - TransportHeaderJunkSize int - InitPacketMagicHeader uint32 - ResponsePacketMagicHeader uint32 - UnderloadPacketMagicHeader uint32 - TransportPacketMagicHeader uint32 - // InitPacketMagicHeader Limit - // ResponsePacketMagicHeader Limit - // UnderloadPacketMagicHeader Limit - // TransportPacketMagicHeader Limit + IsSet bool + JunkPacketCount int + JunkPacketMinSize int + JunkPacketMaxSize int + InitHeaderJunkSize int + ResponseHeaderJunkSize int + CookieReplyHeaderJunkSize int + TransportHeaderJunkSize int + // InitPacketMagicHeader uint32 + // ResponsePacketMagicHeader uint32 + // UnderloadPacketMagicHeader uint32 + // TransportPacketMagicHeader uint32 + + InitPacketMagicHeader Limit + ResponsePacketMagicHeader Limit + UnderloadPacketMagicHeader Limit + TransportPacketMagicHeader Limit } type Limit struct { @@ -49,31 +50,29 @@ func NewLimit(min, max, headerType uint32) (Limit, error) { } func ParseMagicHeader(key, value string, defaultHeaderType uint32) (Limit, error) { - // tempAwg.ASecCfg.InitPacketMagicHeader, err = awg.NewLimit(uint32(initPacketMagicHeaderMin), uint32(initPacketMagicHeaderMax), DNewLimit(min, max, headerType)efaultMessageInitiationType) - // var min, max, headerType uint32 - // _, err := fmt.Sscanf(value, "%d-%d:%d", &min, &max, &headerType) - // if err != nil { - // return Limit{}, fmt.Errorf("invalid magic header format: %s", value) - // } + splitLimits := strings.Split(value, "-") + if len(splitLimits) != 2 { + magicHeader, err := strconv.ParseUint(value, 10, 32) + if err != nil { + return Limit{}, fmt.Errorf("parse key: %s; value: %s; %w", key, value, err) + } - limits := strings.Split(value, "-") - if len(limits) != 2 { - return Limit{}, fmt.Errorf("invalid format for key: %s; %s", key, value) + return NewLimit(uint32(magicHeader), uint32(magicHeader), defaultHeaderType) } - min, err := strconv.ParseUint(limits[0], 10, 32) + min, err := strconv.ParseUint(splitLimits[0], 10, 32) if err != nil { - return Limit{}, fmt.Errorf("parse min key: %s; value: ; %w", key, limits[0], err) + return Limit{}, fmt.Errorf("parse min key: %s; value: %s; %w", key, splitLimits[0], err) } - max, err := strconv.ParseUint(limits[1], 10, 32) + max, err := strconv.ParseUint(splitLimits[1], 10, 32) if err != nil { - return Limit{}, fmt.Errorf("parse max key: %s; value: ; %w", key, limits[0], err) + return Limit{}, fmt.Errorf("parse max key: %s; value: %s; %w", key, splitLimits[1], err) } limit, err := NewLimit(uint32(min), uint32(max), defaultHeaderType) if err != nil { - return Limit{}, fmt.Errorf("new lmit key: %s; value: ; %w", key, limits[0], err) + return Limit{}, fmt.Errorf("new limit key: %s; value: %s-%s; %w", key, splitLimits[0], splitLimits[1], err) } return limit, nil @@ -81,7 +80,7 @@ func ParseMagicHeader(key, value string, defaultHeaderType uint32) (Limit, error type Limits []Limit -func NewLimits(limits []Limit) Limits { +func NewLimits(limits ...Limit) Limits { slices.SortFunc(limits, func(a, b Limit) int { if a.Min < b.Min { return -1 @@ -102,32 +101,30 @@ type Protocol struct { JunkCreator junkCreator HandshakeHandler SpecialHandshakeHandler + + limits Limits } func (protocol *Protocol) CreateInitHeaderJunk() ([]byte, error) { - return protocol.createHeaderJunk(protocol.ASecCfg.InitHeaderJunkSize) + return protocol.createHeaderJunk(protocol.ASecCfg.InitHeaderJunkSize, 0) } func (protocol *Protocol) CreateResponseHeaderJunk() ([]byte, error) { - return protocol.createHeaderJunk(protocol.ASecCfg.ResponseHeaderJunkSize) + return protocol.createHeaderJunk(protocol.ASecCfg.ResponseHeaderJunkSize, 0) } func (protocol *Protocol) CreateCookieReplyHeaderJunk() ([]byte, error) { - return protocol.createHeaderJunk(protocol.ASecCfg.CookieReplyHeaderJunkSize) + return protocol.createHeaderJunk(protocol.ASecCfg.CookieReplyHeaderJunkSize, 0) } func (protocol *Protocol) CreateTransportHeaderJunk(packetSize int) ([]byte, error) { return protocol.createHeaderJunk(protocol.ASecCfg.TransportHeaderJunkSize, packetSize) } -func (protocol *Protocol) createHeaderJunk(junkSize int, optExtraSize ...int) ([]byte, error) { - extraSize := 0 - if len(optExtraSize) == 1 { - extraSize = optExtraSize[0] - } - +func (protocol *Protocol) createHeaderJunk(junkSize int, extraSize int) ([]byte, error) { var junk []byte protocol.ASecMux.RLock() + if junkSize != 0 { buf := make([]byte, 0, junkSize+extraSize) writer := bytes.NewBuffer(buf[:0]) @@ -142,3 +139,13 @@ func (protocol *Protocol) createHeaderJunk(junkSize int, optExtraSize ...int) ([ return junk, nil } + +func (protocol *Protocol) GetLimitMin(msgType uint32) (uint32, error) { + for _, limit := range protocol.limits { + if limit.Min <= msgType && msgType <= limit.Max { + return limit.Min, nil + } + } + + return 0, fmt.Errorf("no limit found for message type: %d", msgType) +} diff --git a/device/device.go b/device/device.go index 7261f4a..a0e04cb 100644 --- a/device/device.go +++ b/device/device.go @@ -6,7 +6,9 @@ package device import ( + "encoding/binary" "errors" + "fmt" "runtime" "sync" "sync/atomic" @@ -646,6 +648,66 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error { isASecOn = true } + if tempAwg.ASecCfg.InitPacketMagicHeader.Min > 4 { + isASecOn = true + device.log.Verbosef("UAPI: Updating init_packet_magic_header") + device.awg.ASecCfg.InitPacketMagicHeader = tempAwg.ASecCfg.InitPacketMagicHeader + MessageInitiationType = device.awg.ASecCfg.InitPacketMagicHeader.Min + } else { + device.log.Verbosef("UAPI: Using default init type") + MessageInitiationType = DefaultMessageInitiationType + } + + if tempAwg.ASecCfg.ResponsePacketMagicHeader.Min > 4 { + isASecOn = true + device.log.Verbosef("UAPI: Updating response_packet_magic_header") + device.awg.ASecCfg.ResponsePacketMagicHeader = tempAwg.ASecCfg.ResponsePacketMagicHeader + MessageResponseType = device.awg.ASecCfg.ResponsePacketMagicHeader.Min + } else { + device.log.Verbosef("UAPI: Using default response type") + MessageResponseType = DefaultMessageResponseType + } + + if tempAwg.ASecCfg.UnderloadPacketMagicHeader.Min > 4 { + isASecOn = true + device.log.Verbosef("UAPI: Updating underload_packet_magic_header") + device.awg.ASecCfg.UnderloadPacketMagicHeader = tempAwg.ASecCfg.UnderloadPacketMagicHeader + MessageCookieReplyType = device.awg.ASecCfg.UnderloadPacketMagicHeader.Min + } else { + device.log.Verbosef("UAPI: Using default underload type") + MessageCookieReplyType = DefaultMessageCookieReplyType + } + + if tempAwg.ASecCfg.TransportPacketMagicHeader.Min > 4 { + isASecOn = true + device.log.Verbosef("UAPI: Updating transport_packet_magic_header") + device.awg.ASecCfg.TransportPacketMagicHeader = tempAwg.ASecCfg.TransportPacketMagicHeader + MessageTransportType = device.awg.ASecCfg.TransportPacketMagicHeader.Min + } else { + device.log.Verbosef("UAPI: Using default transport type") + MessageTransportType = DefaultMessageTransportType + } + + isSameHeaderMap := map[uint32]struct{}{ + MessageInitiationType: {}, + MessageResponseType: {}, + MessageCookieReplyType: {}, + MessageTransportType: {}, + } + + // size will be different if same values + if len(isSameHeaderMap) != 4 { + errs = append(errs, ipcErrorf( + ipc.IpcErrorInvalid, + `magic headers should differ; got: init:%d; recv:%d; unde:%d; tran:%d`, + MessageInitiationType, + MessageResponseType, + MessageCookieReplyType, + MessageTransportType, + ), + ) + } + newInitSize := MessageInitiationSize + tempAwg.ASecCfg.InitHeaderJunkSize if newInitSize >= MaxSegmentSize { @@ -692,7 +754,6 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error { MaxSegmentSize, ), ) -<<<<<<< HEAD } else { device.awg.ASecCfg.CookieReplyHeaderJunkSize = tempAwg.ASecCfg.CookieReplyHeaderJunkSize } @@ -726,100 +787,6 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error { newTransportSize: {}, } -======= - } else { - device.awg.ASecCfg.CookieReplyHeaderJunkSize = tempAwg.ASecCfg.CookieReplyHeaderJunkSize - } - - if tempAwg.ASecCfg.CookieReplyHeaderJunkSize != 0 { - isASecOn = true - } - - newTransportSize := MessageTransportSize + tempAwg.ASecCfg.TransportHeaderJunkSize - - if newTransportSize >= MaxSegmentSize { - errs = append(errs, ipcErrorf( - ipc.IpcErrorInvalid, - `transport size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d`, - tempAwg.ASecCfg.TransportHeaderJunkSize, - MaxSegmentSize, - ), - ) - } else { - device.awg.ASecCfg.TransportHeaderJunkSize = tempAwg.ASecCfg.TransportHeaderJunkSize - } - - if tempAwg.ASecCfg.TransportHeaderJunkSize != 0 { - isASecOn = true - } - - if tempAwg.ASecCfg.InitPacketMagicHeader > 4 { - isASecOn = true - device.log.Verbosef("UAPI: Updating init_packet_magic_header") - device.awg.ASecCfg.InitPacketMagicHeader = tempAwg.ASecCfg.InitPacketMagicHeader - MessageInitiationType = device.awg.ASecCfg.InitPacketMagicHeader - } else { - device.log.Verbosef("UAPI: Using default init type") - MessageInitiationType = DefaultMessageInitiationType - } - - if tempAwg.ASecCfg.ResponsePacketMagicHeader > 4 { - isASecOn = true - device.log.Verbosef("UAPI: Updating response_packet_magic_header") - device.awg.ASecCfg.ResponsePacketMagicHeader = tempAwg.ASecCfg.ResponsePacketMagicHeader - MessageResponseType = device.awg.ASecCfg.ResponsePacketMagicHeader - } else { - device.log.Verbosef("UAPI: Using default response type") - MessageResponseType = DefaultMessageResponseType - } - - if tempAwg.ASecCfg.UnderloadPacketMagicHeader > 4 { - isASecOn = true - device.log.Verbosef("UAPI: Updating underload_packet_magic_header") - device.awg.ASecCfg.UnderloadPacketMagicHeader = tempAwg.ASecCfg.UnderloadPacketMagicHeader - MessageCookieReplyType = device.awg.ASecCfg.UnderloadPacketMagicHeader - } else { - device.log.Verbosef("UAPI: Using default underload type") - MessageCookieReplyType = DefaultMessageCookieReplyType - } - - if tempAwg.ASecCfg.TransportPacketMagicHeader > 4 { - isASecOn = true - device.log.Verbosef("UAPI: Updating transport_packet_magic_header") - device.awg.ASecCfg.TransportPacketMagicHeader = tempAwg.ASecCfg.TransportPacketMagicHeader - MessageTransportType = device.awg.ASecCfg.TransportPacketMagicHeader - } else { - device.log.Verbosef("UAPI: Using default transport type") - MessageTransportType = DefaultMessageTransportType - } - - isSameHeaderMap := map[uint32]struct{}{ - MessageInitiationType: {}, - MessageResponseType: {}, - MessageCookieReplyType: {}, - MessageTransportType: {}, - } - - // size will be different if same values - if len(isSameHeaderMap) != 4 { - errs = append(errs, ipcErrorf( - ipc.IpcErrorInvalid, - `magic headers should differ; got: init:%d; recv:%d; unde:%d; tran:%d`, - MessageInitiationType, - MessageResponseType, - MessageCookieReplyType, - MessageTransportType, - ), - ) - } - - isSameSizeMap := map[int]struct{}{ - newInitSize: {}, - newResponseSize: {}, - newCookieSize: {}, - newTransportSize: {}, - } - if len(isSameSizeMap) != 4 { errs = append(errs, ipcErrorf( ipc.IpcErrorInvalid, @@ -871,3 +838,67 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error { return errors.Join(errs...) } + +var ErrContinueLoop = errors.New("continue processing") + +func (device *Device) Logic(size int, packet *[]byte, bufsArrs *[MaxMessageSize]byte) (msgType uint32, err error) { + // TODO: + // if awg.WaitResponse.ShouldWait.IsSet() { + // awg.WaitResponse.Channel <- struct{}{} + // } + + assumedMsgType, hasAssumedType := packetSizeToMsgType[size] + if !hasAssumedType { + return device.handleTransport(size, packet, bufsArrs) + } + + junkSize := msgTypeToJunkSize[assumedMsgType] + // transport size can align with other header types; + // making sure we have the right msgType + msgType, err = device.getMsgType(packet, junkSize) + if err != nil { + return 0, fmt.Errorf("aSec: get msg type: %w", err) + } + + if msgType == assumedMsgType { + *packet = (*packet)[junkSize:] + return msgType, nil + } + + device.log.Verbosef("transport packet lined up with another msg type") + return device.handleTransport(size, packet, bufsArrs) +} + +func (device *Device) getMsgType(packet *[]byte, junkSize int) (uint32, error) { + msgType := binary.LittleEndian.Uint32((*packet)[junkSize : junkSize+4]) + msgType, err := device.awg.GetLimitMin(msgType) + if err != nil { + return 0, fmt.Errorf("aSec: get limit min for message type %d: %w", msgType, err) + } + return msgType, nil +} + +func (device *Device) handleTransport(size int, packet *[]byte, bufsArrs *[MaxMessageSize]byte) (uint32, error) { + transportJunkSize := device.awg.ASecCfg.TransportHeaderJunkSize + + msgType, err := device.getMsgType(packet, device.awg.ASecCfg.TransportHeaderJunkSize) + if err != nil { + return 0, fmt.Errorf("aSec: get msg type: %w", err) + } + + if msgType != MessageTransportType { + // probably a junk packet + return 0, fmt.Errorf("aSec: Received message with unknown type: %d", msgType) + } + + if transportJunkSize > 0 { + // remove junk from bufsArrs by shifting the packet + // this buffer is also used for decryption, so it needs to be corrected + copy((*bufsArrs)[:size], (*packet)[transportJunkSize:]) + size -= transportJunkSize + // need to reinitialize packet as well + (*packet) = (*packet)[:size] + } + + return msgType, nil +} diff --git a/device/device_test.go b/device/device_test.go index 6eb52e9..e3128e4 100644 --- a/device/device_test.go +++ b/device/device_test.go @@ -232,10 +232,10 @@ func TestAWGDevicePing(t *testing.T) { "jc", "5", "jmin", "500", "jmax", "1000", - "s1", "30", - "s2", "40", - "s3", "50", - "s4", "5", + "s1", "15", + "s2", "18", + "s3", "20", + "s4", "25", "h1", "123456", "h2", "67543", "h3", "123123", diff --git a/device/receive.go b/device/receive.go index 1aaba3a..d1ce3b6 100644 --- a/device/receive.go +++ b/device/receive.go @@ -140,37 +140,10 @@ func (device *Device) RoutineReceiveIncoming( packet := bufsArrs[i][:size] var msgType uint32 if device.isAWG() { - // TODO: - // if awg.WaitResponse.ShouldWait.IsSet() { - // awg.WaitResponse.Channel <- struct{}{} - // } - - if assumedMsgType, ok := packetSizeToMsgType[size]; ok { - junkSize := msgTypeToJunkSize[assumedMsgType] - // transport size can align with other header types; - // making sure we have the right msgType - msgType = binary.LittleEndian.Uint32(packet[junkSize : junkSize+4]) - if msgType == assumedMsgType { - packet = packet[junkSize:] - } else { - device.log.Verbosef("transport packet lined up with another msg type") - msgType = binary.LittleEndian.Uint32(packet[:4]) - } - } else { - transportJunkSize := device.awg.ASecCfg.TransportHeaderJunkSize - msgType = binary.LittleEndian.Uint32(packet[transportJunkSize : transportJunkSize+4]) - if msgType != MessageTransportType { - // probably a junk packet - device.log.Verbosef("aSec: Received message with unknown type: %d", msgType) - continue - } - - // remove junk from bufsArrs by shifting the packet - // this buffer is also used for decryption, so it needs to be corrected - copy(bufsArrs[i][:size], packet[transportJunkSize:]) - size -= transportJunkSize - // need to reinitialize packet as well - packet = packet[:size] + msgType, err = device.Logic(size, &packet, bufsArrs[i]) + if err != nil { + device.log.Verbosef("awg device logic: %w", err) + continue } } else { msgType = binary.LittleEndian.Uint32(packet[:4]) diff --git a/device/uapi.go b/device/uapi.go index 49a08b4..1097e28 100644 --- a/device/uapi.go +++ b/device/uapi.go @@ -370,11 +370,8 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol) tempAwg.ASecCfg.IsSet = true case "h1": - initPacketMagicHeader, err := strconv.ParseUint(value, 10, 32) - if err != nil { - return ipcErrorf(ipc.IpcErrorInvalid, "parse init_packet_magic_header %w", err) - } - tempAwg.ASecCfg.InitPacketMagicHeader = uint32(initPacketMagicHeader) + awg.ParseMagicHeader(key, value, &tempAwg.ASecCfg.InitPacketMagicHeader) + tempAwg.ASecCfg.IsSet = true case "h2": diff --git a/go.mod b/go.mod index 7a72516..b000b63 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,7 @@ go 1.24.4 require ( github.com/stretchr/testify v1.10.0 github.com/tevino/abool v1.2.0 - github.com/tevino/abool/v2 v2.1.0 + go.uber.org/atomic v1.11.0 golang.org/x/crypto v0.36.0 golang.org/x/net v0.37.0 golang.org/x/sys v0.31.0 diff --git a/go.sum b/go.sum index 2312a68..2860376 100644 --- a/go.sum +++ b/go.sum @@ -4,8 +4,14 @@ github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg= github.com/google/btree v1.1.3/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= -github.com/tevino/abool/v2 v2.1.0 h1:7w+Vf9f/5gmKT4m4qkayb33/92M+Um45F2BkHOR+L/c= -github.com/tevino/abool/v2 v2.1.0/go.mod h1:+Lmlqk6bHDWHqN1cbxqhwEAwMPXgc8I1SDEamtseuXY= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/tevino/abool v1.2.0 h1:heAkClL8H6w+mK5md9dzsuohKeXHUpY7Vw0ZCKW+huA= +github.com/tevino/abool v1.2.0/go.mod h1:qc66Pna1RiIsPa7O4Egxxs9OqkuxDX55zznh9K07Tzg= +go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= +go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34= golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc= golang.org/x/mod v0.21.0 h1:vvrHzRwRfVKSiLrG+d4FMl/Qi4ukBCE6kZlTUkDYRT0= @@ -18,5 +24,9 @@ golang.org/x/time v0.9.0 h1:EsRrnYcQiGH+5FfbgvV4AP7qEZstoyrHB0DzarOQ4ZY= golang.org/x/time v0.9.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gvisor.dev/gvisor v0.0.0-20250130013005-04f9204697c6 h1:6B7MdW3OEbJqOMr7cEYU9bkzvCjUBX/JlXk12xcANuQ= gvisor.dev/gvisor v0.0.0-20250130013005-04f9204697c6/go.mod h1:5DMfjtclAbTIjbXqO1qCe2K5GKKxWz2JHvCChuTcJEM=