From 65743536a2a1cfaec6b8b391b82ae5764b5a4a01 Mon Sep 17 00:00:00 2001 From: Mark Puha
"}, + args: args{name: "i1", input: ""}, wantErr: fmt.Errorf("invalid tag"), }, { name: "counter uniqueness violation", - args: args{input: ""}, + args: args{name: "i1", input: " "}, wantErr: fmt.Errorf("parse tag needs to be unique"), }, { name: "timestamp uniqueness violation", - args: args{input: " "}, + args: args{name: "i1", input: " "}, wantErr: fmt.Errorf("parse tag needs to be unique"), }, { @@ -58,7 +64,7 @@ func TestParse(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - _, err := Parse(tt.args.input) + _, err := Parse(tt.args.name, tt.args.input) // TODO: ErrorAs doesn't work as you think if tt.wantErr != nil { diff --git a/device/device.go b/device/device.go index 9ba69dd..7d17ef1 100644 --- a/device/device.go +++ b/device/device.go @@ -6,18 +6,18 @@ package device import ( + "errors" "runtime" "sync" "sync/atomic" "time" "github.com/amnezia-vpn/amneziawg-go/conn" - junktag "github.com/amnezia-vpn/amneziawg-go/device/internal/junk-tag" + "github.com/amnezia-vpn/amneziawg-go/device/awg" "github.com/amnezia-vpn/amneziawg-go/ipc" "github.com/amnezia-vpn/amneziawg-go/ratelimiter" "github.com/amnezia-vpn/amneziawg-go/rwcancel" "github.com/amnezia-vpn/amneziawg-go/tun" - "github.com/tevino/abool/v2" ) type Version uint8 @@ -129,31 +129,7 @@ type Device struct { log *Logger version Version - awg awg -} - -type awg struct { - isASecOn abool.AtomicBool - // TODO: revision the need of the mutex - aSecMux sync.RWMutex - aSecCfg aSecCfgType - junkCreator junkCreator - - // TODO: determine if it's on - handshakeHandler junktag.SpecialHandshakeHandler -} - -type aSecCfgType struct { - isSet bool - junkPacketCount int - junkPacketMinSize int - junkPacketMaxSize int - initPacketJunkSize int - responsePacketJunkSize int - initPacketMagicHeader uint32 - responsePacketMagicHeader uint32 - underloadPacketMagicHeader uint32 - transportPacketMagicHeader uint32 + awg awg.Protocol } // deviceState represents the state of a Device. @@ -603,7 +579,7 @@ func (device *Device) BindClose() error { return err } func (device *Device) isAdvancedSecurityOn() bool { - return device.awg.isASecOn.IsSet() + return device.awg.IsASecOn.IsSet() } func (device *Device) resetProtocol() { @@ -614,165 +590,129 @@ func (device *Device) resetProtocol() { MessageTransportType = DefaultMessageTransportType } -func (device *Device) handlePostConfig(tempAwg *awg) (err error) { - - if !tempAwg.aSecCfg.isSet { - return err +func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error { + if !tempAwg.ASecCfg.IsSet && !tempAwg.HandshakeHandler.IsSet { + return nil } + var errs []error + isASecOn := false - device.awg.aSecMux.Lock() - if tempAwg.aSecCfg.junkPacketCount < 0 { - err = ipcErrorf( + device.awg.ASecMux.Lock() + if tempAwg.ASecCfg.JunkPacketCount < 0 { + errs = append(errs, ipcErrorf( ipc.IpcErrorInvalid, "JunkPacketCount should be non negative", + ), ) } - device.awg.aSecCfg.junkPacketCount = tempAwg.aSecCfg.junkPacketCount - if tempAwg.aSecCfg.junkPacketCount != 0 { + device.awg.ASecCfg.JunkPacketCount = tempAwg.ASecCfg.JunkPacketCount + if tempAwg.ASecCfg.JunkPacketCount != 0 { isASecOn = true } - device.awg.aSecCfg.junkPacketMinSize = tempAwg.aSecCfg.junkPacketMinSize - if tempAwg.aSecCfg.junkPacketMinSize != 0 { + device.awg.ASecCfg.JunkPacketMinSize = tempAwg.ASecCfg.JunkPacketMinSize + if tempAwg.ASecCfg.JunkPacketMinSize != 0 { isASecOn = true } - if device.awg.aSecCfg.junkPacketCount > 0 && - tempAwg.aSecCfg.junkPacketMaxSize == tempAwg.aSecCfg.junkPacketMinSize { + if device.awg.ASecCfg.JunkPacketCount > 0 && + tempAwg.ASecCfg.JunkPacketMaxSize == tempAwg.ASecCfg.JunkPacketMinSize { - tempAwg.aSecCfg.junkPacketMaxSize++ // to make rand gen work + tempAwg.ASecCfg.JunkPacketMaxSize++ // to make rand gen work } - if tempAwg.aSecCfg.junkPacketMaxSize >= MaxSegmentSize { - device.awg.aSecCfg.junkPacketMinSize = 0 - device.awg.aSecCfg.junkPacketMaxSize = 1 - if err != nil { - err = ipcErrorf( - ipc.IpcErrorInvalid, - "JunkPacketMaxSize: %d; should be smaller than maxSegmentSize: %d; %w", - tempAwg.aSecCfg.junkPacketMaxSize, - MaxSegmentSize, - err, - ) - } else { - err = ipcErrorf( - ipc.IpcErrorInvalid, - "JunkPacketMaxSize: %d; should be smaller than maxSegmentSize: %d", - tempAwg.aSecCfg.junkPacketMaxSize, - MaxSegmentSize, - ) - } - } else if tempAwg.aSecCfg.junkPacketMaxSize < tempAwg.aSecCfg.junkPacketMinSize { - if err != nil { - err = ipcErrorf( - ipc.IpcErrorInvalid, - "maxSize: %d; should be greater than minSize: %d; %w", - tempAwg.aSecCfg.junkPacketMaxSize, - tempAwg.aSecCfg.junkPacketMinSize, - err, - ) - } else { - err = ipcErrorf( - ipc.IpcErrorInvalid, - "maxSize: %d; should be greater than minSize: %d", - tempAwg.aSecCfg.junkPacketMaxSize, - tempAwg.aSecCfg.junkPacketMinSize, - ) - } + if tempAwg.ASecCfg.JunkPacketMaxSize >= MaxSegmentSize { + device.awg.ASecCfg.JunkPacketMinSize = 0 + device.awg.ASecCfg.JunkPacketMaxSize = 1 + errs = append(errs, ipcErrorf( + ipc.IpcErrorInvalid, + "JunkPacketMaxSize: %d; should be smaller than maxSegmentSize: %d", + tempAwg.ASecCfg.JunkPacketMaxSize, + MaxSegmentSize, + )) + } else if tempAwg.ASecCfg.JunkPacketMaxSize < tempAwg.ASecCfg.JunkPacketMinSize { + errs = append(errs, ipcErrorf( + ipc.IpcErrorInvalid, + "maxSize: %d; should be greater than minSize: %d", + tempAwg.ASecCfg.JunkPacketMaxSize, + tempAwg.ASecCfg.JunkPacketMinSize, + )) } else { - device.awg.aSecCfg.junkPacketMaxSize = tempAwg.aSecCfg.junkPacketMaxSize + device.awg.ASecCfg.JunkPacketMaxSize = tempAwg.ASecCfg.JunkPacketMaxSize } - if tempAwg.aSecCfg.junkPacketMaxSize != 0 { + if tempAwg.ASecCfg.JunkPacketMaxSize != 0 { isASecOn = true } - if MessageInitiationSize+tempAwg.aSecCfg.initPacketJunkSize >= MaxSegmentSize { - if err != nil { - err = ipcErrorf( - ipc.IpcErrorInvalid, - `init header size(148) + junkSize:%d; should be smaller than maxSegmentSize: %d; %w`, - tempAwg.aSecCfg.initPacketJunkSize, - MaxSegmentSize, - err, - ) - } else { - err = ipcErrorf( - ipc.IpcErrorInvalid, - `init header size(148) + junkSize:%d; should be smaller than maxSegmentSize: %d`, - tempAwg.aSecCfg.initPacketJunkSize, - MaxSegmentSize, - ) - } + if MessageInitiationSize+tempAwg.ASecCfg.InitPacketJunkSize >= MaxSegmentSize { + errs = append(errs, ipcErrorf( + ipc.IpcErrorInvalid, + `init header size(148) + junkSize:%d; should be smaller than maxSegmentSize: %d`, + tempAwg.ASecCfg.InitPacketJunkSize, + MaxSegmentSize, + ), + ) } else { - device.awg.aSecCfg.initPacketJunkSize = tempAwg.aSecCfg.initPacketJunkSize + device.awg.ASecCfg.InitPacketJunkSize = tempAwg.ASecCfg.InitPacketJunkSize } - if tempAwg.aSecCfg.initPacketJunkSize != 0 { + if tempAwg.ASecCfg.InitPacketJunkSize != 0 { isASecOn = true } - if MessageResponseSize+tempAwg.aSecCfg.responsePacketJunkSize >= MaxSegmentSize { - if err != nil { - err = ipcErrorf( - ipc.IpcErrorInvalid, - `response header size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d; %w`, - tempAwg.aSecCfg.responsePacketJunkSize, - MaxSegmentSize, - err, - ) - } else { - err = ipcErrorf( - ipc.IpcErrorInvalid, - `response header size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d`, - tempAwg.aSecCfg.responsePacketJunkSize, - MaxSegmentSize, - ) - } + if MessageResponseSize+tempAwg.ASecCfg.ResponsePacketJunkSize >= MaxSegmentSize { + errs = append(errs, ipcErrorf( + ipc.IpcErrorInvalid, + `response header size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d`, + tempAwg.ASecCfg.ResponsePacketJunkSize, + MaxSegmentSize, + ), + ) } else { - device.awg.aSecCfg.responsePacketJunkSize = tempAwg.aSecCfg.responsePacketJunkSize + device.awg.ASecCfg.ResponsePacketJunkSize = tempAwg.ASecCfg.ResponsePacketJunkSize } - if tempAwg.aSecCfg.responsePacketJunkSize != 0 { + if tempAwg.ASecCfg.ResponsePacketJunkSize != 0 { isASecOn = true } - if tempAwg.aSecCfg.initPacketMagicHeader > 4 { + if tempAwg.ASecCfg.InitPacketMagicHeader > 4 { isASecOn = true device.log.Verbosef("UAPI: Updating init_packet_magic_header") - device.awg.aSecCfg.initPacketMagicHeader = tempAwg.aSecCfg.initPacketMagicHeader - MessageInitiationType = device.awg.aSecCfg.initPacketMagicHeader + device.awg.ASecCfg.InitPacketMagicHeader = tempAwg.ASecCfg.InitPacketMagicHeader + MessageInitiationType = device.awg.ASecCfg.InitPacketMagicHeader } else { device.log.Verbosef("UAPI: Using default init type") MessageInitiationType = DefaultMessageInitiationType } - if tempAwg.aSecCfg.responsePacketMagicHeader > 4 { + if tempAwg.ASecCfg.ResponsePacketMagicHeader > 4 { isASecOn = true device.log.Verbosef("UAPI: Updating response_packet_magic_header") - device.awg.aSecCfg.responsePacketMagicHeader = tempAwg.aSecCfg.responsePacketMagicHeader - MessageResponseType = device.awg.aSecCfg.responsePacketMagicHeader + device.awg.ASecCfg.ResponsePacketMagicHeader = tempAwg.ASecCfg.ResponsePacketMagicHeader + MessageResponseType = device.awg.ASecCfg.ResponsePacketMagicHeader } else { device.log.Verbosef("UAPI: Using default response type") MessageResponseType = DefaultMessageResponseType } - if tempAwg.aSecCfg.underloadPacketMagicHeader > 4 { + if tempAwg.ASecCfg.UnderloadPacketMagicHeader > 4 { isASecOn = true device.log.Verbosef("UAPI: Updating underload_packet_magic_header") - device.awg.aSecCfg.underloadPacketMagicHeader = tempAwg.aSecCfg.underloadPacketMagicHeader - MessageCookieReplyType = device.awg.aSecCfg.underloadPacketMagicHeader + device.awg.ASecCfg.UnderloadPacketMagicHeader = tempAwg.ASecCfg.UnderloadPacketMagicHeader + MessageCookieReplyType = device.awg.ASecCfg.UnderloadPacketMagicHeader } else { device.log.Verbosef("UAPI: Using default underload type") MessageCookieReplyType = DefaultMessageCookieReplyType } - if tempAwg.aSecCfg.transportPacketMagicHeader > 4 { + if tempAwg.ASecCfg.TransportPacketMagicHeader > 4 { isASecOn = true device.log.Verbosef("UAPI: Updating transport_packet_magic_header") - device.awg.aSecCfg.transportPacketMagicHeader = tempAwg.aSecCfg.transportPacketMagicHeader - MessageTransportType = device.awg.aSecCfg.transportPacketMagicHeader + device.awg.ASecCfg.TransportPacketMagicHeader = tempAwg.ASecCfg.TransportPacketMagicHeader + MessageTransportType = device.awg.ASecCfg.TransportPacketMagicHeader } else { device.log.Verbosef("UAPI: Using default transport type") MessageTransportType = DefaultMessageTransportType @@ -787,48 +727,28 @@ func (device *Device) handlePostConfig(tempAwg *awg) (err error) { // size will be different if same values if len(isSameMap) != 4 { - if err != nil { - err = ipcErrorf( - ipc.IpcErrorInvalid, - `magic headers should differ; got: init:%d; recv:%d; unde:%d; tran:%d; %w`, - MessageInitiationType, - MessageResponseType, - MessageCookieReplyType, - MessageTransportType, - err, - ) - } else { - err = ipcErrorf( - ipc.IpcErrorInvalid, - `magic headers should differ; got: init:%d; recv:%d; unde:%d; tran:%d`, - MessageInitiationType, - MessageResponseType, - MessageCookieReplyType, - MessageTransportType, - ) - } + errs = append(errs, ipcErrorf( + ipc.IpcErrorInvalid, + `magic headers should differ; got: init:%d; recv:%d; unde:%d; tran:%d`, + MessageInitiationType, + MessageResponseType, + MessageCookieReplyType, + MessageTransportType, + ), + ) } - newInitSize := MessageInitiationSize + device.awg.aSecCfg.initPacketJunkSize - newResponseSize := MessageResponseSize + device.awg.aSecCfg.responsePacketJunkSize + newInitSize := MessageInitiationSize + device.awg.ASecCfg.InitPacketJunkSize + newResponseSize := MessageResponseSize + device.awg.ASecCfg.ResponsePacketJunkSize if newInitSize == newResponseSize { - if err != nil { - err = ipcErrorf( - ipc.IpcErrorInvalid, - `new init size:%d; and new response size:%d; should differ; %w`, - newInitSize, - newResponseSize, - err, - ) - } else { - err = ipcErrorf( - ipc.IpcErrorInvalid, - `new init size:%d; and new response size:%d; should differ`, - newInitSize, - newResponseSize, - ) - } + errs = append(errs, ipcErrorf( + ipc.IpcErrorInvalid, + `new init size:%d; and new response size:%d; should differ`, + newInitSize, + newResponseSize, + ), + ) } else { packetSizeToMsgType = map[int]uint32{ newInitSize: MessageInitiationType, @@ -838,23 +758,35 @@ func (device *Device) handlePostConfig(tempAwg *awg) (err error) { } msgTypeToJunkSize = map[uint32]int{ - MessageInitiationType: device.awg.aSecCfg.initPacketJunkSize, - MessageResponseType: device.awg.aSecCfg.responsePacketJunkSize, + MessageInitiationType: device.awg.ASecCfg.InitPacketJunkSize, + MessageResponseType: device.awg.ASecCfg.ResponsePacketJunkSize, MessageCookieReplyType: 0, MessageTransportType: 0, } } - if err := tempAwg.handshakeHandler.Validate(); err == nil { - return ipcErrorf(ipc.IpcErrorInvalid, "handle post config foo validate: %w", err) + device.awg.IsASecOn.SetTo(isASecOn) + var err error + device.awg.JunkCreator, err = awg.NewJunkCreator(device.awg.ASecCfg) + if err != nil { + errs = append(errs, err) } - device.awg.isASecOn.SetTo(isASecOn) - device.awg.junkCreator, err = NewJunkCreator(device) - device.awg.handshakeHandler = tempAwg.handshakeHandler - // TODO: - device.version = VersionAwgSpecialHandshake - device.awg.aSecMux.Unlock() + if tempAwg.HandshakeHandler.IsSet { + if err := tempAwg.HandshakeHandler.Validate(); tempAwg.HandshakeHandler.IsSet && err != nil { + errs = append(errs, ipcErrorf( + ipc.IpcErrorInvalid, "handshake handler validate: %w", err)) + } else { + device.awg.HandshakeHandler = tempAwg.HandshakeHandler + device.awg.HandshakeHandler.ControlledJunk.DefaultJunkCount = tempAwg.ASecCfg.JunkPacketCount + device.awg.HandshakeHandler.SpecialJunk.DefaultJunkCount = tempAwg.ASecCfg.JunkPacketCount + device.version = VersionAwgSpecialHandshake + } + } else { + device.version = VersionAwg + } - return err + device.awg.ASecMux.Unlock() + + return errors.Join(errs...) } diff --git a/device/internal/junk-tag/tagged_junk_generator.go b/device/internal/junk-tag/tagged_junk_generator.go deleted file mode 100644 index a89edac..0000000 --- a/device/internal/junk-tag/tagged_junk_generator.go +++ /dev/null @@ -1,42 +0,0 @@ -package junktag - -import ( - "fmt" - "strconv" -) - -type TaggedJunkGenerator struct { - name string - packetSize int - generators []Generator -} - -func newTagedJunkGenerator(name string, size int) TaggedJunkGenerator { - return TaggedJunkGenerator{name: name, generators: make([]Generator, size)} -} - -func (tg *TaggedJunkGenerator) append(generator Generator) { - tg.generators = append(tg.generators, generator) - tg.packetSize += generator.Size() -} - -func (tg *TaggedJunkGenerator) generate() []byte { - packet := make([]byte, 0, tg.packetSize) - for _, generator := range tg.generators { - packet = append(packet, generator.Generate()...) - } - - return packet -} - -func (t *TaggedJunkGenerator) nameIndex() (int, error) { - if len(t.name) != 2 { - return 0, fmt.Errorf("name must be 2 character long: %s", t.name) - } - - index, err := strconv.Atoi(t.name[1:2]) - if err != nil { - return 0, fmt.Errorf("name should be 2 char long: %w", err) - } - return index, nil -} diff --git a/device/noise-protocol.go b/device/noise-protocol.go index d774904..89c634c 100644 --- a/device/noise-protocol.go +++ b/device/noise-protocol.go @@ -204,12 +204,12 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e handshake.mixHash(handshake.remoteStatic[:]) - device.awg.aSecMux.RLock() + device.awg.ASecMux.RLock() msg := MessageInitiation{ Type: MessageInitiationType, Ephemeral: handshake.localEphemeral.publicKey(), } - device.awg.aSecMux.RUnlock() + device.awg.ASecMux.RUnlock() handshake.mixKey(msg.Ephemeral[:]) handshake.mixHash(msg.Ephemeral[:]) @@ -263,12 +263,12 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer { chainKey [blake2s.Size]byte ) - device.awg.aSecMux.RLock() + device.awg.ASecMux.RLock() if msg.Type != MessageInitiationType { - device.awg.aSecMux.RUnlock() + device.awg.ASecMux.RUnlock() return nil } - device.awg.aSecMux.RUnlock() + device.awg.ASecMux.RUnlock() device.staticIdentity.RLock() defer device.staticIdentity.RUnlock() @@ -383,9 +383,9 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error } var msg MessageResponse - device.awg.aSecMux.RLock() + device.awg.ASecMux.RLock() msg.Type = MessageResponseType - device.awg.aSecMux.RUnlock() + device.awg.ASecMux.RUnlock() msg.Sender = handshake.localIndex msg.Receiver = handshake.remoteIndex @@ -435,12 +435,12 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error } func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { - device.awg.aSecMux.RLock() + device.awg.ASecMux.RLock() if msg.Type != MessageResponseType { - device.awg.aSecMux.RUnlock() + device.awg.ASecMux.RUnlock() return nil } - device.awg.aSecMux.RUnlock() + device.awg.ASecMux.RUnlock() // lookup handshake by receiver diff --git a/device/peer.go b/device/peer.go index bdc52fc..9409ca6 100644 --- a/device/peer.go +++ b/device/peer.go @@ -137,7 +137,7 @@ func (peer *Peer) SendBuffers(buffers [][]byte) error { if err == nil { var totalLen uint64 for _, b := range buffers { - peer.device.awg.foo.PacketCounter++ + peer.device.awg.HandshakeHandler.PacketCounter++ totalLen += uint64(len(b)) } peer.txBytes.Add(totalLen) diff --git a/device/receive.go b/device/receive.go index f235a33..1875c76 100644 --- a/device/receive.go +++ b/device/receive.go @@ -129,7 +129,7 @@ func (device *Device) RoutineReceiveIncoming( } deathSpiral = 0 - device.awg.aSecMux.RLock() + device.awg.ASecMux.RLock() // handle each packet in the batch for i, size := range sizes[:count] { if size < MinMessageSize { @@ -246,7 +246,7 @@ func (device *Device) RoutineReceiveIncoming( default: } } - device.awg.aSecMux.RUnlock() + device.awg.ASecMux.RUnlock() for peer, elemsContainer := range elemsByPeer { if peer.isRunning.Load() { peer.queue.inbound.c <- elemsContainer @@ -305,7 +305,7 @@ func (device *Device) RoutineHandshake(id int) { for elem := range device.queue.handshake.c { - device.awg.aSecMux.RLock() + device.awg.ASecMux.RLock() // handle cookie fields and ratelimiting @@ -457,7 +457,7 @@ func (device *Device) RoutineHandshake(id int) { peer.SendKeepalive() } skip: - device.awg.aSecMux.RUnlock() + device.awg.ASecMux.RUnlock() device.PutMessageBuffer(elem.buffer) } } diff --git a/device/send.go b/device/send.go index f49b0e1..2e446cd 100644 --- a/device/send.go +++ b/device/send.go @@ -128,19 +128,21 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { var junkedHeader []byte if peer.device.version >= VersionAwg { - junks := [][]byte{} + var junks [][]byte if peer.device.version == VersionAwgSpecialHandshake { - peer.device.awg.aSecMux.RLock() + peer.device.awg.ASecMux.RLock() // set junks depending on packet type - junks = peer.device.awg.handshakeHandler.GenerateSpecialJunk() + junks = peer.device.awg.HandshakeHandler.GenerateSpecialJunk() if junks == nil { - junks = peer.device.awg.handshakeHandler.GenerateSpecialJunk() + junks = peer.device.awg.HandshakeHandler.GenerateSpecialJunk() } - peer.device.awg.aSecMux.RUnlock() + peer.device.awg.ASecMux.RUnlock() + } else { + junks = make([][]byte, peer.device.awg.ASecCfg.JunkPacketCount) } - peer.device.awg.aSecMux.RLock() - err := peer.device.awg.junkCreator.createJunkPackets(&junks) - peer.device.awg.aSecMux.RUnlock() + peer.device.awg.ASecMux.RLock() + err := peer.device.awg.JunkCreator.CreateJunkPackets(junks) + peer.device.awg.ASecMux.RUnlock() if err != nil { peer.device.log.Errorf("%v - %v", peer, err) @@ -156,19 +158,19 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { } } - peer.device.awg.aSecMux.RLock() - if peer.device.awg.aSecCfg.initPacketJunkSize != 0 { - buf := make([]byte, 0, peer.device.awg.aSecCfg.initPacketJunkSize) + peer.device.awg.ASecMux.RLock() + if peer.device.awg.ASecCfg.InitPacketJunkSize != 0 { + buf := make([]byte, 0, peer.device.awg.ASecCfg.InitPacketJunkSize) writer := bytes.NewBuffer(buf[:0]) - err = peer.device.awg.junkCreator.appendJunk(writer, peer.device.awg.aSecCfg.initPacketJunkSize) + err = peer.device.awg.JunkCreator.AppendJunk(writer, peer.device.awg.ASecCfg.InitPacketJunkSize) if err != nil { peer.device.log.Errorf("%v - %v", peer, err) - peer.device.awg.aSecMux.RUnlock() + peer.device.awg.ASecMux.RUnlock() return err } junkedHeader = writer.Bytes() } - peer.device.awg.aSecMux.RUnlock() + peer.device.awg.ASecMux.RUnlock() } var buf [MessageInitiationSize]byte @@ -206,19 +208,19 @@ func (peer *Peer) SendHandshakeResponse() error { } var junkedHeader []byte if peer.device.isAdvancedSecurityOn() { - peer.device.awg.aSecMux.RLock() - if peer.device.awg.aSecCfg.responsePacketJunkSize != 0 { - buf := make([]byte, 0, peer.device.awg.aSecCfg.responsePacketJunkSize) + peer.device.awg.ASecMux.RLock() + if peer.device.awg.ASecCfg.ResponsePacketJunkSize != 0 { + buf := make([]byte, 0, peer.device.awg.ASecCfg.ResponsePacketJunkSize) writer := bytes.NewBuffer(buf[:0]) - err = peer.device.awg.junkCreator.appendJunk(writer, peer.device.awg.aSecCfg.responsePacketJunkSize) + err = peer.device.awg.JunkCreator.AppendJunk(writer, peer.device.awg.ASecCfg.ResponsePacketJunkSize) if err != nil { - peer.device.awg.aSecMux.RUnlock() + peer.device.awg.ASecMux.RUnlock() peer.device.log.Errorf("%v - %v", peer, err) return err } junkedHeader = writer.Bytes() } - peer.device.awg.aSecMux.RUnlock() + peer.device.awg.ASecMux.RUnlock() } var buf [MessageResponseSize]byte writer := bytes.NewBuffer(buf[:0]) diff --git a/device/uapi.go b/device/uapi.go index 7ae4b1c..216f7cd 100644 --- a/device/uapi.go +++ b/device/uapi.go @@ -18,7 +18,7 @@ import ( "sync" "time" - junktag "github.com/amnezia-vpn/amneziawg-go/device/internal/junk-tag" + "github.com/amnezia-vpn/amneziawg-go/device/awg" "github.com/amnezia-vpn/amneziawg-go/ipc" ) @@ -99,33 +99,37 @@ func (device *Device) IpcGetOperation(w io.Writer) error { } if device.isAdvancedSecurityOn() { - if device.awg.aSecCfg.junkPacketCount != 0 { - sendf("jc=%d", device.awg.aSecCfg.junkPacketCount) + if device.awg.ASecCfg.JunkPacketCount != 0 { + sendf("jc=%d", device.awg.ASecCfg.JunkPacketCount) } - if device.awg.aSecCfg.junkPacketMinSize != 0 { - sendf("jmin=%d", device.awg.aSecCfg.junkPacketMinSize) + if device.awg.ASecCfg.JunkPacketMinSize != 0 { + sendf("jmin=%d", device.awg.ASecCfg.JunkPacketMinSize) } - if device.awg.aSecCfg.junkPacketMaxSize != 0 { - sendf("jmax=%d", device.awg.aSecCfg.junkPacketMaxSize) + if device.awg.ASecCfg.JunkPacketMaxSize != 0 { + sendf("jmax=%d", device.awg.ASecCfg.JunkPacketMaxSize) } - if device.awg.aSecCfg.initPacketJunkSize != 0 { - sendf("s1=%d", device.awg.aSecCfg.initPacketJunkSize) + if device.awg.ASecCfg.InitPacketJunkSize != 0 { + sendf("s1=%d", device.awg.ASecCfg.InitPacketJunkSize) } - if device.awg.aSecCfg.responsePacketJunkSize != 0 { - sendf("s2=%d", device.awg.aSecCfg.responsePacketJunkSize) + if device.awg.ASecCfg.ResponsePacketJunkSize != 0 { + sendf("s2=%d", device.awg.ASecCfg.ResponsePacketJunkSize) } - if device.awg.aSecCfg.initPacketMagicHeader != 0 { - sendf("h1=%d", device.awg.aSecCfg.initPacketMagicHeader) + if device.awg.ASecCfg.InitPacketMagicHeader != 0 { + sendf("h1=%d", device.awg.ASecCfg.InitPacketMagicHeader) } - if device.awg.aSecCfg.responsePacketMagicHeader != 0 { - sendf("h2=%d", device.awg.aSecCfg.responsePacketMagicHeader) + if device.awg.ASecCfg.ResponsePacketMagicHeader != 0 { + sendf("h2=%d", device.awg.ASecCfg.ResponsePacketMagicHeader) } - if device.awg.aSecCfg.underloadPacketMagicHeader != 0 { - sendf("h3=%d", device.awg.aSecCfg.underloadPacketMagicHeader) + if device.awg.ASecCfg.UnderloadPacketMagicHeader != 0 { + sendf("h3=%d", device.awg.ASecCfg.UnderloadPacketMagicHeader) } - if device.awg.aSecCfg.transportPacketMagicHeader != 0 { - sendf("h4=%d", device.awg.aSecCfg.transportPacketMagicHeader) + if device.awg.ASecCfg.TransportPacketMagicHeader != 0 { + sendf("h4=%d", device.awg.ASecCfg.TransportPacketMagicHeader) } + + // for _, generator := range device.awg.HandshakeHandler.ControlledJunk.AppendGenerator { + // sendf("h4=%d", device.awg.ASecCfg.TransportPacketMagicHeader) + // } } for _, peer := range device.peers.keyMap { @@ -181,7 +185,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) { peer := new(ipcSetPeer) deviceConfig := true - tempAwg := awg{} + tempAwg := awg.Protocol{} scanner := bufio.NewScanner(r) for scanner.Scan() { line := scanner.Text() @@ -238,7 +242,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) { return nil } -func (device *Device) handleDeviceLine(key, value string, tempAwg *awg) error { +func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol) error { switch key { case "private_key": var sk NoisePrivateKey @@ -290,8 +294,8 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg) error { return ipcErrorf(ipc.IpcErrorInvalid, "parse junk_packet_count %w", err) } device.log.Verbosef("UAPI: Updating junk_packet_count") - tempAwg.aSecCfg.junkPacketCount = junkPacketCount - tempAwg.aSecCfg.isSet = true + tempAwg.ASecCfg.JunkPacketCount = junkPacketCount + tempAwg.ASecCfg.IsSet = true case "jmin": junkPacketMinSize, err := strconv.Atoi(value) @@ -299,8 +303,8 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg) error { return ipcErrorf(ipc.IpcErrorInvalid, "parse junk_packet_min_size %w", err) } device.log.Verbosef("UAPI: Updating junk_packet_min_size") - tempAwg.aSecCfg.junkPacketMinSize = junkPacketMinSize - tempAwg.aSecCfg.isSet = true + tempAwg.ASecCfg.JunkPacketMinSize = junkPacketMinSize + tempAwg.ASecCfg.IsSet = true case "jmax": junkPacketMaxSize, err := strconv.Atoi(value) @@ -308,8 +312,8 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg) error { return ipcErrorf(ipc.IpcErrorInvalid, "parse junk_packet_max_size %w", err) } device.log.Verbosef("UAPI: Updating junk_packet_max_size") - tempAwg.aSecCfg.junkPacketMaxSize = junkPacketMaxSize - tempAwg.aSecCfg.isSet = true + tempAwg.ASecCfg.JunkPacketMaxSize = junkPacketMaxSize + tempAwg.ASecCfg.IsSet = true case "s1": initPacketJunkSize, err := strconv.Atoi(value) @@ -317,8 +321,8 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg) error { return ipcErrorf(ipc.IpcErrorInvalid, "parse init_packet_junk_size %w", err) } device.log.Verbosef("UAPI: Updating init_packet_junk_size") - tempAwg.aSecCfg.initPacketJunkSize = initPacketJunkSize - tempAwg.aSecCfg.isSet = true + tempAwg.ASecCfg.InitPacketJunkSize = initPacketJunkSize + tempAwg.ASecCfg.IsSet = true case "s2": responsePacketJunkSize, err := strconv.Atoi(value) @@ -326,65 +330,65 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg) error { return ipcErrorf(ipc.IpcErrorInvalid, "parse response_packet_junk_size %w", err) } device.log.Verbosef("UAPI: Updating response_packet_junk_size") - tempAwg.aSecCfg.responsePacketJunkSize = responsePacketJunkSize - tempAwg.aSecCfg.isSet = true + tempAwg.ASecCfg.ResponsePacketJunkSize = responsePacketJunkSize + tempAwg.ASecCfg.IsSet = true case "h1": initPacketMagicHeader, err := strconv.ParseUint(value, 10, 32) if err != nil { return ipcErrorf(ipc.IpcErrorInvalid, "parse init_packet_magic_header %w", err) } - tempAwg.aSecCfg.initPacketMagicHeader = uint32(initPacketMagicHeader) - tempAwg.aSecCfg.isSet = true + tempAwg.ASecCfg.InitPacketMagicHeader = uint32(initPacketMagicHeader) + tempAwg.ASecCfg.IsSet = true case "h2": responsePacketMagicHeader, err := strconv.ParseUint(value, 10, 32) if err != nil { return ipcErrorf(ipc.IpcErrorInvalid, "parse response_packet_magic_header %w", err) } - tempAwg.aSecCfg.responsePacketMagicHeader = uint32(responsePacketMagicHeader) - tempAwg.aSecCfg.isSet = true + tempAwg.ASecCfg.ResponsePacketMagicHeader = uint32(responsePacketMagicHeader) + tempAwg.ASecCfg.IsSet = true case "h3": underloadPacketMagicHeader, err := strconv.ParseUint(value, 10, 32) if err != nil { return ipcErrorf(ipc.IpcErrorInvalid, "parse underload_packet_magic_header %w", err) } - tempAwg.aSecCfg.underloadPacketMagicHeader = uint32(underloadPacketMagicHeader) - tempAwg.aSecCfg.isSet = true + tempAwg.ASecCfg.UnderloadPacketMagicHeader = uint32(underloadPacketMagicHeader) + tempAwg.ASecCfg.IsSet = true case "h4": transportPacketMagicHeader, err := strconv.ParseUint(value, 10, 32) if err != nil { return ipcErrorf(ipc.IpcErrorInvalid, "parse transport_packet_magic_header %w", err) } - tempAwg.aSecCfg.transportPacketMagicHeader = uint32(transportPacketMagicHeader) - tempAwg.aSecCfg.isSet = true + tempAwg.ASecCfg.TransportPacketMagicHeader = uint32(transportPacketMagicHeader) + tempAwg.ASecCfg.IsSet = true case "i1", "i2", "i3", "i4", "i5": if len(value) == 0 { return ipcErrorf(ipc.IpcErrorInvalid, "%s should be non null", key) } - generators, err := junktag.Parse(key, value) + generators, err := awg.Parse(key, value) if err != nil { return ipcErrorf(ipc.IpcErrorInvalid, "invalid %s: %w", key, err) } device.log.Verbosef("UAPI: Updating %s", key) - tempAwg.handshakeHandler.SpecialJunk.AppendGenerator(generators) - tempAwg.aSecCfg.isSet = true + tempAwg.HandshakeHandler.SpecialJunk.AppendGenerator(generators) + tempAwg.HandshakeHandler.IsSet = true case "j1", "j2", "j3": if len(value) == 0 { return ipcErrorf(ipc.IpcErrorInvalid, "%s should be non null", key) } - generators, err := junktag.Parse(key, value) + generators, err := awg.Parse(key, value) if err != nil { return ipcErrorf(ipc.IpcErrorInvalid, "invalid %s: %w", key, err) } device.log.Verbosef("UAPI: Updating %s", key) - tempAwg.handshakeHandler.ControlledJunk.AppendGenerator(generators) - tempAwg.aSecCfg.isSet = true + tempAwg.HandshakeHandler.ControlledJunk.AppendGenerator(generators) + tempAwg.HandshakeHandler.IsSet = true case "itime": itime, err := strconv.ParseInt(value, 10, 64) if err != nil { @@ -392,8 +396,8 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg) error { } device.log.Verbosef("UAPI: Updating itime %s", itime) - tempAwg.handshakeHandler.ITimeout = time.Duration(itime) - tempAwg.aSecCfg.isSet = true + tempAwg.HandshakeHandler.ITimeout = time.Duration(itime) + tempAwg.HandshakeHandler.IsSet = true default: return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI device key: %v", key) } diff --git a/go.mod b/go.mod index c150a4b..5896772 100644 --- a/go.mod +++ b/go.mod @@ -4,12 +4,12 @@ go 1.24 require ( github.com/stretchr/testify v1.10.0 - github.com/tevino/abool/v2 v2.1.0 + github.com/tevino/abool v1.2.0 golang.org/x/crypto v0.36.0 golang.org/x/net v0.37.0 golang.org/x/sys v0.31.0 golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 - gvisor.dev/gvisor v0.0.0-20250130013005-04f9204697c6 + gvisor.dev/gvisor v0.0.0-20250606233247-e3c4c4cad86f ) require ( diff --git a/go.sum b/go.sum index ec3fa7f..4b1e64a 100644 --- a/go.sum +++ b/go.sum @@ -8,8 +8,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= -github.com/tevino/abool/v2 v2.1.0 h1:7w+Vf9f/5gmKT4m4qkayb33/92M+Um45F2BkHOR+L/c= -github.com/tevino/abool/v2 v2.1.0/go.mod h1:+Lmlqk6bHDWHqN1cbxqhwEAwMPXgc8I1SDEamtseuXY= +github.com/tevino/abool v1.2.0 h1:heAkClL8H6w+mK5md9dzsuohKeXHUpY7Vw0ZCKW+huA= +github.com/tevino/abool v1.2.0/go.mod h1:qc66Pna1RiIsPa7O4Egxxs9OqkuxDX55zznh9K07Tzg= golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34= golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc= golang.org/x/mod v0.21.0 h1:vvrHzRwRfVKSiLrG+d4FMl/Qi4ukBCE6kZlTUkDYRT0= @@ -26,5 +26,5 @@ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -gvisor.dev/gvisor v0.0.0-20250130013005-04f9204697c6 h1:6B7MdW3OEbJqOMr7cEYU9bkzvCjUBX/JlXk12xcANuQ= -gvisor.dev/gvisor v0.0.0-20250130013005-04f9204697c6/go.mod h1:5DMfjtclAbTIjbXqO1qCe2K5GKKxWz2JHvCChuTcJEM= +gvisor.dev/gvisor v0.0.0-20250606233247-e3c4c4cad86f h1:zmc4cHEcCudRt2O8VsCW7nYLfAsbVY2i910/DAop1TM= +gvisor.dev/gvisor v0.0.0-20250606233247-e3c4c4cad86f/go.mod h1:3r5CMtNQMKIvBlrmM9xWUNamjKBYPOWyXOjmg5Kts3g=