diff --git a/device/awg/awg.go b/device/awg/awg.go index 838e7ca..ef97313 100644 --- a/device/awg/awg.go +++ b/device/awg/awg.go @@ -3,7 +3,6 @@ package awg import ( "bytes" "fmt" - "slices" "strconv" "strings" "sync" @@ -32,24 +31,29 @@ type aSecCfgType struct { } type Limit struct { - Min uint32 - Max uint32 - HeaderType uint32 + Min uint32 + Max uint32 } -func NewLimit(min, max, headerType uint32) (Limit, error) { +func NewLimitSameValue(value uint32) Limit { + return Limit{ + Min: value, + Max: value, + } +} + +func NewLimit(min, max uint32) (Limit, error) { if min > max { return Limit{}, fmt.Errorf("min (%d) cannot be greater than max (%d)", min, max) } return Limit{ - Min: min, - Max: max, - HeaderType: headerType, + Min: min, + Max: max, }, nil } -func ParseMagicHeader(key, value string, defaultHeaderType uint32) (Limit, error) { +func ParseMagicHeader(key, value string) (Limit, error) { splitLimits := strings.Split(value, "-") if len(splitLimits) != 2 { magicHeader, err := strconv.ParseUint(value, 10, 32) @@ -57,7 +61,7 @@ func ParseMagicHeader(key, value string, defaultHeaderType uint32) (Limit, error return Limit{}, fmt.Errorf("parse key: %s; value: %s; %w", key, value, err) } - return NewLimit(uint32(magicHeader), uint32(magicHeader), defaultHeaderType) + return NewLimit(uint32(magicHeader), uint32(magicHeader)) } min, err := strconv.ParseUint(splitLimits[0], 10, 32) @@ -70,7 +74,7 @@ func ParseMagicHeader(key, value string, defaultHeaderType uint32) (Limit, error return Limit{}, fmt.Errorf("parse max key: %s; value: %s; %w", key, splitLimits[1], err) } - limit, err := NewLimit(uint32(min), uint32(max), defaultHeaderType) + limit, err := NewLimit(uint32(min), uint32(max)) if err != nil { return Limit{}, fmt.Errorf("new limit key: %s; value: %s-%s; %w", key, splitLimits[0], splitLimits[1], err) } @@ -78,19 +82,22 @@ func ParseMagicHeader(key, value string, defaultHeaderType uint32) (Limit, error return limit, nil } -type Limits []Limit +type Limits struct { + Limits []Limit + randomGenerator PRNG[uint32] +} -func NewLimits(limits ...Limit) Limits { - slices.SortFunc(limits, func(a, b Limit) int { - if a.Min < b.Min { - return -1 - } else if a.Min > b.Min { - return 1 - } - return 0 - }) +func NewLimits(limits []Limit) Limits { + // TODO: check if limits doesn't overlap + return Limits{Limits: limits, randomGenerator: NewPRNG[uint32]()} +} - return Limits(limits) +func (l *Limits) Get(defaultMsgType uint32) (uint32, error) { + if defaultMsgType == 0 || defaultMsgType > 4 { + return 0, fmt.Errorf("invalid message type: %d", defaultMsgType) + } + + return l.randomGenerator.RandomSizeInRange(l.Limits[defaultMsgType-1].Min, l.Limits[defaultMsgType-1].Max), nil } type Protocol struct { @@ -102,7 +109,7 @@ type Protocol struct { HandshakeHandler SpecialHandshakeHandler - limits Limits + MagicHeaders Limits } func (protocol *Protocol) CreateInitHeaderJunk() ([]byte, error) { @@ -141,7 +148,8 @@ func (protocol *Protocol) createHeaderJunk(junkSize int, extraSize int) ([]byte, } func (protocol *Protocol) GetLimitMin(msgType uint32) (uint32, error) { - for _, limit := range protocol.limits { + fmt.Println(protocol.MagicHeaders.Limits) + for _, limit := range protocol.MagicHeaders.Limits { if limit.Min <= msgType && msgType <= limit.Max { return limit.Min, nil } @@ -149,3 +157,7 @@ func (protocol *Protocol) GetLimitMin(msgType uint32) (uint32, error) { return 0, fmt.Errorf("no limit found for message type: %d", msgType) } + +func (protocol *Protocol) Get(defaultMsgType uint32) (uint32, error) { + return protocol.MagicHeaders.Get(defaultMsgType) +} diff --git a/device/awg/util.go b/device/awg/util.go index 5cb0caa..164e2b2 100644 --- a/device/awg/util.go +++ b/device/awg/util.go @@ -21,7 +21,7 @@ func NewPRNG[T constraints.Integer]() PRNG[T] { } func (p PRNG[T]) RandomSizeInRange(min, max T) T { - if min >= max { + if min > max { panic("min must be less than max") } diff --git a/device/cookie.go b/device/cookie.go index 876f05d..5dfb8f5 100644 --- a/device/cookie.go +++ b/device/cookie.go @@ -118,6 +118,7 @@ func (st *CookieChecker) CreateReply( msg []byte, recv uint32, src []byte, + msgType uint32, ) (*MessageCookieReply, error) { st.RLock() @@ -153,7 +154,7 @@ func (st *CookieChecker) CreateReply( smac1 := smac2 - blake2s.Size128 reply := new(MessageCookieReply) - reply.Type = MessageCookieReplyType + reply.Type = msgType reply.Receiver = recv _, err := rand.Read(reply.Nonce[:]) diff --git a/device/cookie_test.go b/device/cookie_test.go index 4f1e50a..d076004 100644 --- a/device/cookie_test.go +++ b/device/cookie_test.go @@ -99,7 +99,7 @@ func TestCookieMAC1(t *testing.T) { 0x8c, 0xe1, 0xe8, 0xfa, 0x67, 0x20, 0x80, 0x6d, } generator.AddMacs(msg) - reply, err := checker.CreateReply(msg, 1377, src) + reply, err := checker.CreateReply(msg, 1377, src, DefaultMessageCookieReplyType) if err != nil { t.Fatal("Failed to create cookie reply:", err) } diff --git a/device/device.go b/device/device.go index a0e04cb..22b73d0 100644 --- a/device/device.go +++ b/device/device.go @@ -648,24 +648,29 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error { isASecOn = true } + limits := make([]awg.Limit, 4) + if tempAwg.ASecCfg.InitPacketMagicHeader.Min > 4 { isASecOn = true device.log.Verbosef("UAPI: Updating init_packet_magic_header") device.awg.ASecCfg.InitPacketMagicHeader = tempAwg.ASecCfg.InitPacketMagicHeader + limits[0] = tempAwg.ASecCfg.InitPacketMagicHeader MessageInitiationType = device.awg.ASecCfg.InitPacketMagicHeader.Min } else { device.log.Verbosef("UAPI: Using default init type") MessageInitiationType = DefaultMessageInitiationType + limits[0] = awg.NewLimitSameValue(DefaultMessageInitiationType) } if tempAwg.ASecCfg.ResponsePacketMagicHeader.Min > 4 { isASecOn = true device.log.Verbosef("UAPI: Updating response_packet_magic_header") - device.awg.ASecCfg.ResponsePacketMagicHeader = tempAwg.ASecCfg.ResponsePacketMagicHeader MessageResponseType = device.awg.ASecCfg.ResponsePacketMagicHeader.Min + limits[1] = tempAwg.ASecCfg.ResponsePacketMagicHeader } else { device.log.Verbosef("UAPI: Using default response type") MessageResponseType = DefaultMessageResponseType + limits[1] = awg.NewLimitSameValue(DefaultMessageResponseType) } if tempAwg.ASecCfg.UnderloadPacketMagicHeader.Min > 4 { @@ -673,9 +678,11 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error { device.log.Verbosef("UAPI: Updating underload_packet_magic_header") device.awg.ASecCfg.UnderloadPacketMagicHeader = tempAwg.ASecCfg.UnderloadPacketMagicHeader MessageCookieReplyType = device.awg.ASecCfg.UnderloadPacketMagicHeader.Min + limits[2] = tempAwg.ASecCfg.UnderloadPacketMagicHeader } else { device.log.Verbosef("UAPI: Using default underload type") MessageCookieReplyType = DefaultMessageCookieReplyType + limits[2] = awg.NewLimitSameValue(DefaultMessageCookieReplyType) } if tempAwg.ASecCfg.TransportPacketMagicHeader.Min > 4 { @@ -683,9 +690,11 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error { device.log.Verbosef("UAPI: Updating transport_packet_magic_header") device.awg.ASecCfg.TransportPacketMagicHeader = tempAwg.ASecCfg.TransportPacketMagicHeader MessageTransportType = device.awg.ASecCfg.TransportPacketMagicHeader.Min + limits[3] = tempAwg.ASecCfg.UnderloadPacketMagicHeader } else { device.log.Verbosef("UAPI: Using default transport type") MessageTransportType = DefaultMessageTransportType + limits[3] = awg.NewLimitSameValue(DefaultMessageTransportType) } isSameHeaderMap := map[uint32]struct{}{ @@ -708,6 +717,8 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error { ) } + device.awg.MagicHeaders = awg.NewLimits(limits) + newInitSize := MessageInitiationSize + tempAwg.ASecCfg.InitHeaderJunkSize if newInitSize >= MaxSegmentSize { @@ -814,11 +825,7 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error { } device.awg.IsASecOn.SetTo(isASecOn) - var err error - device.awg.JunkCreator, err = awg.NewJunkCreator(device.awg.ASecCfg) - if err != nil { - errs = append(errs, err) - } + device.awg.JunkCreator = awg.NewJunkCreator(device.awg.ASecCfg) if tempAwg.HandshakeHandler.IsSet { if err := tempAwg.HandshakeHandler.Validate(); err != nil { diff --git a/device/noise-protocol.go b/device/noise-protocol.go index ff1ecc1..57a677d 100644 --- a/device/noise-protocol.go +++ b/device/noise-protocol.go @@ -206,8 +206,14 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e handshake.mixHash(handshake.remoteStatic[:]) device.awg.ASecMux.RLock() + msgType, err := device.awg.Get(DefaultMessageInitiationType) + if err != nil { + device.awg.ASecMux.RUnlock() + return nil, fmt.Errorf("get message type: %w", err) + } + msg := MessageInitiation{ - Type: MessageInitiationType, + Type: msgType, Ephemeral: handshake.localEphemeral.publicKey(), } device.awg.ASecMux.RUnlock() @@ -385,7 +391,12 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error var msg MessageResponse device.awg.ASecMux.RLock() - msg.Type = MessageResponseType + msg.Type, err = device.awg.Get(DefaultMessageResponseType) + if err != nil { + device.awg.ASecMux.RUnlock() + return nil, fmt.Errorf("get message type: %w", err) + } + device.awg.ASecMux.RUnlock() msg.Sender = handshake.localIndex msg.Receiver = handshake.remoteIndex diff --git a/device/send.go b/device/send.go index f0db894..7a4bb16 100644 --- a/device/send.go +++ b/device/send.go @@ -146,7 +146,7 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { junks = make([][]byte, 0, peer.device.awg.ASecCfg.JunkPacketCount) } peer.device.awg.ASecMux.RLock() - err := peer.device.awg.JunkCreator.CreateJunkPackets(&junks) + peer.device.awg.JunkCreator.CreateJunkPackets(&junks) peer.device.awg.ASecMux.RUnlock() if err != nil { @@ -242,10 +242,17 @@ func (device *Device) SendHandshakeCookie( device.log.Verbosef("Sending cookie response for denied handshake message for %v", initiatingElem.endpoint.DstToString()) sender := binary.LittleEndian.Uint32(initiatingElem.packet[4:8]) + msgType, err := device.awg.Get(DefaultMessageCookieReplyType) + if err != nil { + device.log.Errorf("Get message type for cookie reply: %v", err) + return err + } + reply, err := device.cookieChecker.CreateReply( initiatingElem.packet, sender, initiatingElem.endpoint.DstToBytes(), + msgType, ) if err != nil { device.log.Errorf("Failed to create cookie reply: %v", err) @@ -528,7 +535,12 @@ func (device *Device) RoutineEncryption(id int) { fieldReceiver := header[4:8] fieldNonce := header[8:16] - binary.LittleEndian.PutUint32(fieldType, MessageTransportType) + msgType, err := device.awg.Get(DefaultMessageTransportType) + if err != nil { + device.log.Errorf("Get message type for transport: %v", err) + continue + } + binary.LittleEndian.PutUint32(fieldType, msgType) binary.LittleEndian.PutUint32(fieldReceiver, elem.keypair.remoteIndex) binary.LittleEndian.PutUint64(fieldNonce, elem.nonce) diff --git a/device/uapi.go b/device/uapi.go index 1097e28..ce0152e 100644 --- a/device/uapi.go +++ b/device/uapi.go @@ -120,18 +120,19 @@ func (device *Device) IpcGetOperation(w io.Writer) error { if device.awg.ASecCfg.TransportHeaderJunkSize != 0 { sendf("s4=%d", device.awg.ASecCfg.TransportHeaderJunkSize) } - if device.awg.ASecCfg.InitPacketMagicHeader != 0 { - sendf("h1=%d", device.awg.ASecCfg.InitPacketMagicHeader) - } - if device.awg.ASecCfg.ResponsePacketMagicHeader != 0 { - sendf("h2=%d", device.awg.ASecCfg.ResponsePacketMagicHeader) - } - if device.awg.ASecCfg.UnderloadPacketMagicHeader != 0 { - sendf("h3=%d", device.awg.ASecCfg.UnderloadPacketMagicHeader) - } - if device.awg.ASecCfg.TransportPacketMagicHeader != 0 { - sendf("h4=%d", device.awg.ASecCfg.TransportPacketMagicHeader) - } + // TODO: + // if device.awg.ASecCfg.InitPacketMagicHeader != 0 { + // sendf("h1=%d", device.awg.ASecCfg.InitPacketMagicHeader) + // } + // if device.awg.ASecCfg.ResponsePacketMagicHeader != 0 { + // sendf("h2=%d", device.awg.ASecCfg.ResponsePacketMagicHeader) + // } + // if device.awg.ASecCfg.UnderloadPacketMagicHeader != 0 { + // sendf("h3=%d", device.awg.ASecCfg.UnderloadPacketMagicHeader) + // } + // if device.awg.ASecCfg.TransportPacketMagicHeader != 0 { + // sendf("h4=%d", device.awg.ASecCfg.TransportPacketMagicHeader) + // } specialJunkIpcFields := device.awg.HandshakeHandler.SpecialJunk.IpcGetFields() for _, field := range specialJunkIpcFields { @@ -370,33 +371,41 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol) tempAwg.ASecCfg.IsSet = true case "h1": - awg.ParseMagicHeader(key, value, &tempAwg.ASecCfg.InitPacketMagicHeader) + magicHeader, err := awg.ParseMagicHeader(key, value) + if err != nil { + return ipcErrorf(ipc.IpcErrorInvalid, "uapi: %w", err) + } + tempAwg.ASecCfg.InitPacketMagicHeader = magicHeader tempAwg.ASecCfg.IsSet = true case "h2": - responsePacketMagicHeader, err := strconv.ParseUint(value, 10, 32) + magicHeader, err := awg.ParseMagicHeader(key, value) if err != nil { - return ipcErrorf(ipc.IpcErrorInvalid, "parse response_packet_magic_header %w", err) + return ipcErrorf(ipc.IpcErrorInvalid, "uapi: %w", err) } - tempAwg.ASecCfg.ResponsePacketMagicHeader = uint32(responsePacketMagicHeader) + + tempAwg.ASecCfg.ResponsePacketMagicHeader = magicHeader tempAwg.ASecCfg.IsSet = true case "h3": - underloadPacketMagicHeader, err := strconv.ParseUint(value, 10, 32) + magicHeader, err := awg.ParseMagicHeader(key, value) if err != nil { - return ipcErrorf(ipc.IpcErrorInvalid, "parse underload_packet_magic_header %w", err) + return ipcErrorf(ipc.IpcErrorInvalid, "uapi: %w", err) } - tempAwg.ASecCfg.UnderloadPacketMagicHeader = uint32(underloadPacketMagicHeader) + + tempAwg.ASecCfg.UnderloadPacketMagicHeader = magicHeader tempAwg.ASecCfg.IsSet = true case "h4": - transportPacketMagicHeader, err := strconv.ParseUint(value, 10, 32) + magicHeader, err := awg.ParseMagicHeader(key, value) if err != nil { - return ipcErrorf(ipc.IpcErrorInvalid, "parse transport_packet_magic_header %w", err) + return ipcErrorf(ipc.IpcErrorInvalid, "uapi: %w", err) } - tempAwg.ASecCfg.TransportPacketMagicHeader = uint32(transportPacketMagicHeader) + + tempAwg.ASecCfg.TransportPacketMagicHeader = magicHeader tempAwg.ASecCfg.IsSet = true + case "i1", "i2", "i3", "i4", "i5": if len(value) == 0 { device.log.Verbosef("UAPI: received empty %s", key)