diff --git a/device/awg/awg.go b/device/awg/awg.go index ef97313..6138c5f 100644 --- a/device/awg/awg.go +++ b/device/awg/awg.go @@ -3,8 +3,6 @@ package awg import ( "bytes" "fmt" - "strconv" - "strings" "sync" "github.com/tevino/abool" @@ -19,145 +17,81 @@ type aSecCfgType struct { ResponseHeaderJunkSize int CookieReplyHeaderJunkSize int TransportHeaderJunkSize int - // InitPacketMagicHeader uint32 - // ResponsePacketMagicHeader uint32 - // UnderloadPacketMagicHeader uint32 - // TransportPacketMagicHeader uint32 - InitPacketMagicHeader Limit - ResponsePacketMagicHeader Limit - UnderloadPacketMagicHeader Limit - TransportPacketMagicHeader Limit -} - -type Limit struct { - Min uint32 - Max uint32 -} - -func NewLimitSameValue(value uint32) Limit { - return Limit{ - Min: value, - Max: value, - } -} - -func NewLimit(min, max uint32) (Limit, error) { - if min > max { - return Limit{}, fmt.Errorf("min (%d) cannot be greater than max (%d)", min, max) - } - - return Limit{ - Min: min, - Max: max, - }, nil -} - -func ParseMagicHeader(key, value string) (Limit, error) { - splitLimits := strings.Split(value, "-") - if len(splitLimits) != 2 { - magicHeader, err := strconv.ParseUint(value, 10, 32) - if err != nil { - return Limit{}, fmt.Errorf("parse key: %s; value: %s; %w", key, value, err) - } - - return NewLimit(uint32(magicHeader), uint32(magicHeader)) - } - - min, err := strconv.ParseUint(splitLimits[0], 10, 32) - if err != nil { - return Limit{}, fmt.Errorf("parse min key: %s; value: %s; %w", key, splitLimits[0], err) - } - - max, err := strconv.ParseUint(splitLimits[1], 10, 32) - if err != nil { - return Limit{}, fmt.Errorf("parse max key: %s; value: %s; %w", key, splitLimits[1], err) - } - - limit, err := NewLimit(uint32(min), uint32(max)) - if err != nil { - return Limit{}, fmt.Errorf("new limit key: %s; value: %s-%s; %w", key, splitLimits[0], splitLimits[1], err) - } - - return limit, nil -} - -type Limits struct { - Limits []Limit - randomGenerator PRNG[uint32] -} - -func NewLimits(limits []Limit) Limits { - // TODO: check if limits doesn't overlap - return Limits{Limits: limits, randomGenerator: NewPRNG[uint32]()} -} - -func (l *Limits) Get(defaultMsgType uint32) (uint32, error) { - if defaultMsgType == 0 || defaultMsgType > 4 { - return 0, fmt.Errorf("invalid message type: %d", defaultMsgType) - } - - return l.randomGenerator.RandomSizeInRange(l.Limits[defaultMsgType-1].Min, l.Limits[defaultMsgType-1].Max), nil + InitPacketMagicHeader MagicHeader + ResponsePacketMagicHeader MagicHeader + UnderloadPacketMagicHeader MagicHeader + TransportPacketMagicHeader MagicHeader } type Protocol struct { IsASecOn abool.AtomicBool // TODO: revision the need of the mutex - ASecMux sync.RWMutex - ASecCfg aSecCfgType - JunkCreator junkCreator + ASecMux sync.RWMutex + ASecCfg aSecCfgType + JunkCreator junkCreator + MagicHeaders MagicHeaders HandshakeHandler SpecialHandshakeHandler - - MagicHeaders Limits } func (protocol *Protocol) CreateInitHeaderJunk() ([]byte, error) { + protocol.ASecMux.RLock() + defer protocol.ASecMux.RUnlock() + return protocol.createHeaderJunk(protocol.ASecCfg.InitHeaderJunkSize, 0) } func (protocol *Protocol) CreateResponseHeaderJunk() ([]byte, error) { + protocol.ASecMux.RLock() + defer protocol.ASecMux.RUnlock() + return protocol.createHeaderJunk(protocol.ASecCfg.ResponseHeaderJunkSize, 0) } func (protocol *Protocol) CreateCookieReplyHeaderJunk() ([]byte, error) { + protocol.ASecMux.RLock() + defer protocol.ASecMux.RUnlock() + return protocol.createHeaderJunk(protocol.ASecCfg.CookieReplyHeaderJunkSize, 0) } func (protocol *Protocol) CreateTransportHeaderJunk(packetSize int) ([]byte, error) { + protocol.ASecMux.RLock() + defer protocol.ASecMux.RUnlock() + return protocol.createHeaderJunk(protocol.ASecCfg.TransportHeaderJunkSize, packetSize) } func (protocol *Protocol) createHeaderJunk(junkSize int, extraSize int) ([]byte, error) { - var junk []byte - protocol.ASecMux.RLock() - - if junkSize != 0 { - buf := make([]byte, 0, junkSize+extraSize) - writer := bytes.NewBuffer(buf[:0]) - err := protocol.JunkCreator.AppendJunk(writer, junkSize) - if err != nil { - protocol.ASecMux.RUnlock() - return nil, err - } - junk = writer.Bytes() + if junkSize == 0 { + return nil, nil } - protocol.ASecMux.RUnlock() + + var junk []byte + buf := make([]byte, 0, junkSize+extraSize) + writer := bytes.NewBuffer(buf[:0]) + + err := protocol.JunkCreator.AppendJunk(writer, junkSize) + if err != nil { + return nil, fmt.Errorf("append junk: %w", err) + } + + junk = writer.Bytes() return junk, nil } -func (protocol *Protocol) GetLimitMin(msgType uint32) (uint32, error) { - fmt.Println(protocol.MagicHeaders.Limits) - for _, limit := range protocol.MagicHeaders.Limits { - if limit.Min <= msgType && msgType <= limit.Max { +func (protocol *Protocol) GetLimitMin(msgTypeRange uint32) (uint32, error) { + for _, limit := range protocol.MagicHeaders.headers { + if limit.Min <= msgTypeRange && msgTypeRange <= limit.Max { return limit.Min, nil } } - return 0, fmt.Errorf("no limit found for message type: %d", msgType) + return 0, fmt.Errorf("no limit for range: %d", msgTypeRange) } -func (protocol *Protocol) Get(defaultMsgType uint32) (uint32, error) { +func (protocol *Protocol) GetMsgType(defaultMsgType uint32) (uint32, error) { return protocol.MagicHeaders.Get(defaultMsgType) } diff --git a/device/awg/magic_header.go b/device/awg/magic_header.go new file mode 100644 index 0000000..f7ad45c --- /dev/null +++ b/device/awg/magic_header.go @@ -0,0 +1,91 @@ +package awg + +import ( + "cmp" + "fmt" + "slices" + "strconv" + "strings" +) + +type MagicHeader struct { + Min uint32 + Max uint32 +} + +func NewMagicHeaderSameValue(value uint32) MagicHeader { + return MagicHeader{Min: value, Max: value} +} + +func NewMagicHeader(min, max uint32) (MagicHeader, error) { + if min > max { + return MagicHeader{}, fmt.Errorf("min (%d) cannot be greater than max (%d)", min, max) + } + + return MagicHeader{Min: min, Max: max}, nil +} + +func ParseMagicHeader(key, value string) (MagicHeader, error) { + splitLimits := strings.Split(value, "-") + if len(splitLimits) != 2 { + // if there is no hyphen, we treat it as single magic header value + magicHeader, err := strconv.ParseUint(value, 10, 32) + if err != nil { + return MagicHeader{}, fmt.Errorf("parse key: %s; value: %s; %w", key, value, err) + } + + return NewMagicHeader(uint32(magicHeader), uint32(magicHeader)) + } + + min, err := strconv.ParseUint(splitLimits[0], 10, 32) + if err != nil { + return MagicHeader{}, fmt.Errorf("parse min key: %s; value: %s; %w", key, splitLimits[0], err) + } + + max, err := strconv.ParseUint(splitLimits[1], 10, 32) + if err != nil { + return MagicHeader{}, fmt.Errorf("parse max key: %s; value: %s; %w", key, splitLimits[1], err) + } + + magicHeader, err := NewMagicHeader(uint32(min), uint32(max)) + if err != nil { + return MagicHeader{}, fmt.Errorf("new magicHeader key: %s; value: %s-%s; %w", key, splitLimits[0], splitLimits[1], err) + } + + return magicHeader, nil +} + +type MagicHeaders struct { + headers []MagicHeader + randomGenerator PRNG[uint32] +} + +func NewMagicHeaders(magicHeaders []MagicHeader) (MagicHeaders, error) { + if len(magicHeaders) != 4 { + return MagicHeaders{}, fmt.Errorf("all header types should be included: %v", magicHeaders) + } + + sortedMagicHeaders := slices.SortedFunc(slices.Values(magicHeaders), func(lhs MagicHeader, rhs MagicHeader) int { + return cmp.Compare(lhs.Min, rhs.Min) + }) + + for i := range 3 { + if sortedMagicHeaders[i].Min > sortedMagicHeaders[i+1].Min { + return MagicHeaders{}, fmt.Errorf( + "magic headers shouldn't overlap; %v > %v", + sortedMagicHeaders[i-1].Min, + sortedMagicHeaders[i].Min, + ) + } + } + + return MagicHeaders{headers: magicHeaders, randomGenerator: NewPRNG[uint32]()}, nil +} + +func (mh *MagicHeaders) Get(defaultMsgType uint32) (uint32, error) { + if defaultMsgType == 0 || defaultMsgType > 4 { + return 0, fmt.Errorf("invalid msg type: %d", defaultMsgType) + } + + return mh.randomGenerator.RandomSizeInRange(mh.headers[defaultMsgType-1].Min, mh.headers[defaultMsgType-1].Max), nil +} diff --git a/device/device.go b/device/device.go index dbb975b..1f1c035 100644 --- a/device/device.go +++ b/device/device.go @@ -648,7 +648,7 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error { isASecOn = true } - limits := make([]awg.Limit, 4) + limits := make([]awg.MagicHeader, 4) if tempAwg.ASecCfg.InitPacketMagicHeader.Min > 4 { isASecOn = true @@ -659,7 +659,7 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error { } else { device.log.Verbosef("UAPI: Using default init type") MessageInitiationType = DefaultMessageInitiationType - limits[0] = awg.NewLimitSameValue(DefaultMessageInitiationType) + limits[0] = awg.NewMagicHeaderSameValue(DefaultMessageInitiationType) } if tempAwg.ASecCfg.ResponsePacketMagicHeader.Min > 4 { @@ -671,7 +671,7 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error { } else { device.log.Verbosef("UAPI: Using default response type") MessageResponseType = DefaultMessageResponseType - limits[1] = awg.NewLimitSameValue(DefaultMessageResponseType) + limits[1] = awg.NewMagicHeaderSameValue(DefaultMessageResponseType) } if tempAwg.ASecCfg.UnderloadPacketMagicHeader.Min > 4 { @@ -683,7 +683,7 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error { } else { device.log.Verbosef("UAPI: Using default underload type") MessageCookieReplyType = DefaultMessageCookieReplyType - limits[2] = awg.NewLimitSameValue(DefaultMessageCookieReplyType) + limits[2] = awg.NewMagicHeaderSameValue(DefaultMessageCookieReplyType) } if tempAwg.ASecCfg.TransportPacketMagicHeader.Min > 4 { @@ -695,7 +695,7 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error { } else { device.log.Verbosef("UAPI: Using default transport type") MessageTransportType = DefaultMessageTransportType - limits[3] = awg.NewLimitSameValue(DefaultMessageTransportType) + limits[3] = awg.NewMagicHeaderSameValue(DefaultMessageTransportType) } isSameHeaderMap := map[uint32]struct{}{ @@ -705,7 +705,11 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error { MessageTransportType: {}, } - device.awg.MagicHeaders = awg.NewLimits(limits) + var err error + device.awg.MagicHeaders, err = awg.NewMagicHeaders(limits) + if err != nil { + errs = append(errs, ipcErrorf(ipc.IpcErrorInvalid, "new magic headers: %w", err)) + } // size will be different if same values if len(isSameHeaderMap) != 4 { @@ -859,8 +863,6 @@ func (device *Device) Logic(size int, packet *[]byte, bufsArrs *[MaxMessageSize] } junkSize := msgTypeToJunkSize[assumedMsgType] - fmt.Println(msgTypeToJunkSize) - fmt.Printf("Assumed message type: %d; size: %d", assumedMsgType, junkSize) // transport size can align with other header types; // making sure we have the right msgType @@ -875,6 +877,7 @@ func (device *Device) Logic(size int, packet *[]byte, bufsArrs *[MaxMessageSize] } device.log.Verbosef("transport packet lined up with another msg type") + return device.handleTransport(size, packet, bufsArrs) } @@ -883,8 +886,9 @@ func (device *Device) getMsgType(packet *[]byte, junkSize int) (uint32, error) { msgType, err := device.awg.GetLimitMin(msgTypeRange) if err != nil { - return 0, fmt.Errorf("aSec: get limit min for message type range: %d; %w", msgTypeRange, err) + return 0, fmt.Errorf("get limit min: %w", err) } + return msgType, nil } diff --git a/device/noise-protocol.go b/device/noise-protocol.go index aa7b0e6..bce2d2f 100644 --- a/device/noise-protocol.go +++ b/device/noise-protocol.go @@ -206,7 +206,7 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e handshake.mixHash(handshake.remoteStatic[:]) device.awg.ASecMux.RLock() - msgType, err := device.awg.Get(DefaultMessageInitiationType) + msgType, err := device.awg.GetMsgType(DefaultMessageInitiationType) if err != nil { device.awg.ASecMux.RUnlock() return nil, fmt.Errorf("get message type: %w", err) @@ -392,7 +392,7 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error var msg MessageResponse device.awg.ASecMux.RLock() - msg.Type, err = device.awg.Get(DefaultMessageResponseType) + msg.Type, err = device.awg.GetMsgType(DefaultMessageResponseType) if err != nil { device.awg.ASecMux.RUnlock() return nil, fmt.Errorf("get message type: %w", err) diff --git a/device/send.go b/device/send.go index 4d07066..b87dda2 100644 --- a/device/send.go +++ b/device/send.go @@ -149,11 +149,6 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { peer.device.awg.JunkCreator.CreateJunkPackets(&junks) peer.device.awg.ASecMux.RUnlock() - if err != nil { - peer.device.log.Errorf("%v - %v", peer, err) - return err - } - if len(junks) > 0 { err = peer.SendBuffers(junks) @@ -242,7 +237,7 @@ func (device *Device) SendHandshakeCookie( device.log.Verbosef("Sending cookie response for denied handshake message for %v", initiatingElem.endpoint.DstToString()) sender := binary.LittleEndian.Uint32(initiatingElem.packet[4:8]) - msgType, err := device.awg.Get(DefaultMessageCookieReplyType) + msgType, err := device.awg.GetMsgType(DefaultMessageCookieReplyType) if err != nil { device.log.Errorf("Get message type for cookie reply: %v", err) return err @@ -535,7 +530,7 @@ func (device *Device) RoutineEncryption(id int) { fieldReceiver := header[4:8] fieldNonce := header[8:16] - msgType, err := device.awg.Get(DefaultMessageTransportType) + msgType, err := device.awg.GetMsgType(DefaultMessageTransportType) if err != nil { device.log.Errorf("get message type for transport: %v", err) continue diff --git a/device/uapi.go b/device/uapi.go index ce0152e..d99a4ee 100644 --- a/device/uapi.go +++ b/device/uapi.go @@ -371,39 +371,39 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol) tempAwg.ASecCfg.IsSet = true case "h1": - magicHeader, err := awg.ParseMagicHeader(key, value) + initMagicHeader, err := awg.ParseMagicHeader(key, value) if err != nil { return ipcErrorf(ipc.IpcErrorInvalid, "uapi: %w", err) } - tempAwg.ASecCfg.InitPacketMagicHeader = magicHeader + tempAwg.ASecCfg.InitPacketMagicHeader = initMagicHeader tempAwg.ASecCfg.IsSet = true case "h2": - magicHeader, err := awg.ParseMagicHeader(key, value) + responseMagicHeader, err := awg.ParseMagicHeader(key, value) if err != nil { return ipcErrorf(ipc.IpcErrorInvalid, "uapi: %w", err) } - tempAwg.ASecCfg.ResponsePacketMagicHeader = magicHeader + tempAwg.ASecCfg.ResponsePacketMagicHeader = responseMagicHeader tempAwg.ASecCfg.IsSet = true case "h3": - magicHeader, err := awg.ParseMagicHeader(key, value) + cookieReplyMagicHeader, err := awg.ParseMagicHeader(key, value) if err != nil { return ipcErrorf(ipc.IpcErrorInvalid, "uapi: %w", err) } - tempAwg.ASecCfg.UnderloadPacketMagicHeader = magicHeader + tempAwg.ASecCfg.UnderloadPacketMagicHeader = cookieReplyMagicHeader tempAwg.ASecCfg.IsSet = true case "h4": - magicHeader, err := awg.ParseMagicHeader(key, value) + transportMagicHeader, err := awg.ParseMagicHeader(key, value) if err != nil { return ipcErrorf(ipc.IpcErrorInvalid, "uapi: %w", err) } - tempAwg.ASecCfg.TransportPacketMagicHeader = magicHeader + tempAwg.ASecCfg.TransportPacketMagicHeader = transportMagicHeader tempAwg.ASecCfg.IsSet = true case "i1", "i2", "i3", "i4", "i5":