From 339134bc2574f24863ba8487c602e4f609ea68ca Mon Sep 17 00:00:00 2001 From: Mark Puha Date: Sat, 1 Feb 2025 16:45:30 +0100 Subject: [PATCH 1/3] AWG-2: delete obsolate junk impl --- device/util.go | 25 ------------------------- device/util_test.go | 27 --------------------------- 2 files changed, 52 deletions(-) delete mode 100644 device/util.go delete mode 100644 device/util_test.go diff --git a/device/util.go b/device/util.go deleted file mode 100644 index aab8ab7..0000000 --- a/device/util.go +++ /dev/null @@ -1,25 +0,0 @@ -package device - -import ( - "bytes" - crand "crypto/rand" - "fmt" -) - -func appendJunk(writer *bytes.Buffer, size int) error { - headerJunk, err := randomJunkWithSize(size) - if err != nil { - return fmt.Errorf("failed to create header junk: %v", err) - } - _, err = writer.Write(headerJunk) - if err != nil { - return fmt.Errorf("failed to write header junk: %v", err) - } - return nil -} - -func randomJunkWithSize(size int) ([]byte, error) { - junk := make([]byte, size) - _, err := crand.Read(junk) - return junk, err -} diff --git a/device/util_test.go b/device/util_test.go deleted file mode 100644 index c061eef..0000000 --- a/device/util_test.go +++ /dev/null @@ -1,27 +0,0 @@ -package device - -import ( - "bytes" - "fmt" - "testing" -) - -func Test_randomJunktWithSize(t *testing.T) { - junk, err := randomJunkWithSize(30) - fmt.Println(string(junk), len(junk), err) -} - -func Test_appendJunk(t *testing.T) { - t.Run("", func(t *testing.T) { - s := "apple" - buffer := bytes.NewBuffer([]byte(s)) - err := appendJunk(buffer, 30) - if err != nil && - buffer.Len() != len(s)+30 { - t.Errorf("appendWithJunk() size don't match") - } - read := make([]byte, 50) - buffer.Read(read) - fmt.Println(string(read)) - }) -} From 9a56c052cc591e1fcf4474ae323e52bb59a9b8fb Mon Sep 17 00:00:00 2001 From: Mark Puha Date: Sat, 1 Feb 2025 16:45:59 +0100 Subject: [PATCH 2/3] AWG-2 update go version --- go.mod | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/go.mod b/go.mod index 115ae88..b03acb0 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/amnezia-vpn/amneziawg-go -go 1.22.3 +go 1.23 require ( github.com/tevino/abool/v2 v2.1.0 From 971144c9fb039722299b2e13632b1b19ad221571 Mon Sep 17 00:00:00 2001 From: Mark Puha Date: Sat, 1 Feb 2025 16:46:38 +0100 Subject: [PATCH 3/3] AWG-2 create&integrate junk_creator --- device/device.go | 31 +++++++-- device/device_test.go | 5 +- device/junk_creator.go | 74 +++++++++++++++++++++ device/junk_creator_test.go | 124 ++++++++++++++++++++++++++++++++++++ device/send.go | 107 ++++++++++++++++++------------- main.go | 7 +- 6 files changed, 294 insertions(+), 54 deletions(-) create mode 100644 device/junk_creator.go create mode 100644 device/junk_creator_test.go diff --git a/device/device.go b/device/device.go index 80e3793..1bedce3 100644 --- a/device/device.go +++ b/device/device.go @@ -95,6 +95,8 @@ type Device struct { isASecOn abool.AtomicBool aSecMux sync.RWMutex aSecCfg aSecCfgType + + junkCreator junkCreator } type aSecCfgType struct { @@ -161,7 +163,10 @@ func (device *Device) changeState(want deviceState) (err error) { old := device.deviceState() if old == deviceStateClosed { // once closed, always closed - device.log.Verbosef("Interface closed, ignored requested state %s", want) + device.log.Verbosef( + "Interface closed, ignored requested state %s", + want, + ) return nil } switch want { @@ -182,7 +187,11 @@ func (device *Device) changeState(want deviceState) (err error) { } } device.log.Verbosef( - "Interface state was %s, requested %s, now %s", old, want, device.deviceState()) + "Interface state was %s, requested %s, now %s", + old, + want, + device.deviceState(), + ) return } @@ -287,7 +296,9 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error { expiredPeers := make([]*Peer, 0, len(device.peers.keyMap)) for _, peer := range device.peers.keyMap { handshake := &peer.handshake - handshake.precomputedStaticStatic, _ = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic) + handshake.precomputedStaticStatic, _ = device.staticIdentity.privateKey.sharedSecret( + handshake.remoteStatic, + ) expiredPeers = append(expiredPeers, peer) } @@ -433,7 +444,9 @@ func (device *Device) SendKeepalivesToPeersWithCurrentKeypair() { device.peers.RLock() for _, peer := range device.peers.keyMap { peer.keypairs.RLock() - sendKeepalive := peer.keypairs.current != nil && !peer.keypairs.current.created.Add(RejectAfterTime).Before(time.Now()) + sendKeepalive := peer.keypairs.current != nil && + !peer.keypairs.current.created.Add(RejectAfterTime). + Before(time.Now()) peer.keypairs.RUnlock() if sendKeepalive { peer.SendKeepalive() @@ -539,8 +552,12 @@ func (device *Device) BindUpdate() error { // start receiving routines device.net.stopping.Add(len(recvFns)) - device.queue.decryption.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption - device.queue.handshake.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.handshake + device.queue.decryption.wg.Add( + len(recvFns), + ) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption + device.queue.handshake.wg.Add( + len(recvFns), + ) // each RoutineReceiveIncoming goroutine writes to device.queue.handshake batchSize := netc.bind.BatchSize() for _, fn := range recvFns { go device.RoutineReceiveIncoming(batchSize, fn) @@ -569,7 +586,6 @@ func (device *Device) resetProtocol() { } func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) { - if !tempASecCfg.isSet { return err } @@ -799,6 +815,7 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) { } device.isASecOn.SetTo(isASecOn) + device.junkCreator, err = NewJunkCreator(device) device.aSecMux.Unlock() return err diff --git a/device/device_test.go b/device/device_test.go index e6664a6..e904f26 100644 --- a/device/device_test.go +++ b/device/device_test.go @@ -109,7 +109,7 @@ func genASecurityConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) { "replace_peers", "true", "jc", "5", "jmin", "500", - "jmax", "501", + "jmax", "1000", "s1", "30", "s2", "40", "h1", "123456", @@ -131,7 +131,7 @@ func genASecurityConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) { "replace_peers", "true", "jc", "5", "jmin", "500", - "jmax", "501", + "jmax", "1000", "s1", "30", "s2", "40", "h1", "123456", @@ -274,6 +274,7 @@ func TestTwoDevicePing(t *testing.T) { }) } +// Run test with -race=false to avoid the race for setting the default msgTypes 2 times func TestTwoDevicePingASecurity(t *testing.T) { goroutineLeakCheck(t) pair := genTestPair(t, true, true) diff --git a/device/junk_creator.go b/device/junk_creator.go new file mode 100644 index 0000000..85a5bbc --- /dev/null +++ b/device/junk_creator.go @@ -0,0 +1,74 @@ +package device + +import ( + "bytes" + crand "crypto/rand" + "fmt" + v2 "math/rand/v2" +) + +type junkCreator struct { + device *Device + cha8Rand *v2.ChaCha8 +} + +func NewJunkCreator(d *Device) (junkCreator, error) { + buf := make([]byte, 32) + _, err := crand.Read(buf) + if err != nil { + return junkCreator{}, err + } + return junkCreator{device: d, cha8Rand: v2.NewChaCha8([32]byte(buf))}, nil +} + +// Should be called with aSecMux RLocked +func (jc *junkCreator) createJunkPackets(peer *Peer) ([][]byte, error) { + if jc.device.aSecCfg.junkPacketCount == 0 { + return nil, nil + } + + junks := make([][]byte, 0, jc.device.aSecCfg.junkPacketCount) + for i := 0; i < jc.device.aSecCfg.junkPacketCount; i++ { + packetSize := jc.randomPacketSize() + junk, err := jc.randomJunkWithSize(packetSize) + if err != nil { + jc.device.log.Errorf( + "%v - Failed to create junk packet: %v", + peer, + err, + ) + return nil, err + } + junks = append(junks, junk) + } + return junks, nil +} + +// Should be called with aSecMux RLocked +func (jc *junkCreator) randomPacketSize() int { + return int( + jc.cha8Rand.Uint64()%uint64( + jc.device.aSecCfg.junkPacketMaxSize-jc.device.aSecCfg.junkPacketMinSize, + ), + ) + jc.device.aSecCfg.junkPacketMinSize +} + +// Should be called with aSecMux RLocked +func (jc *junkCreator) appendJunk(writer *bytes.Buffer, size int) error { + headerJunk, err := jc.randomJunkWithSize(size) + if err != nil { + return fmt.Errorf("failed to create header junk: %v", err) + } + _, err = writer.Write(headerJunk) + if err != nil { + return fmt.Errorf("failed to write header junk: %v", err) + } + return nil +} + +// Should be called with aSecMux RLocked +func (jc *junkCreator) randomJunkWithSize(size int) ([]byte, error) { + junk := make([]byte, size) + _, err := jc.cha8Rand.Read(junk) + return junk, err +} diff --git a/device/junk_creator_test.go b/device/junk_creator_test.go new file mode 100644 index 0000000..6f63360 --- /dev/null +++ b/device/junk_creator_test.go @@ -0,0 +1,124 @@ +package device + +import ( + "bytes" + "fmt" + "testing" + + "github.com/amnezia-vpn/amneziawg-go/conn/bindtest" + "github.com/amnezia-vpn/amneziawg-go/tun/tuntest" +) + +func setUpJunkCreator(t *testing.T) (junkCreator, error) { + cfg, _ := genASecurityConfigs(t) + tun := tuntest.NewChannelTUN() + binds := bindtest.NewChannelBinds() + level := LogLevelVerbose + dev := NewDevice( + tun.TUN(), + binds[0], + NewLogger(level, ""), + ) + + if err := dev.IpcSet(cfg[0]); err != nil { + t.Errorf("failed to configure device %v", err) + dev.Close() + return junkCreator{}, err + } + + jc, err := NewJunkCreator(dev) + + if err != nil { + t.Errorf("failed to create junk creator %v", err) + dev.Close() + return junkCreator{}, err + } + + return jc, nil +} + +func Test_junkCreator_createJunkPackets(t *testing.T) { + jc, err := setUpJunkCreator(t) + if err != nil { + return + } + t.Run("", func(t *testing.T) { + got, err := jc.createJunkPackets(nil) + if err != nil { + t.Errorf( + "junkCreator.createJunkPackets() = %v; failed", + err, + ) + return + } + seen := make(map[string]bool) + for _, junk := range got { + key := string(junk) + if seen[key] { + t.Errorf( + "junkCreator.createJunkPackets() = %v, duplicate key: %v", + got, + junk, + ) + return + } + seen[key] = true + } + }) +} + +func Test_junkCreator_randomJunkWithSize(t *testing.T) { + t.Run("", func(t *testing.T) { + jc, err := setUpJunkCreator(t) + if err != nil { + return + } + r1, _ := jc.randomJunkWithSize(10) + r2, _ := jc.randomJunkWithSize(10) + fmt.Printf("%v\n%v\n", r1, r2) + if bytes.Equal(r1, r2) { + t.Errorf("same junks %v", err) + jc.device.Close() + return + } + }) +} + +func Test_junkCreator_randomPacketSize(t *testing.T) { + jc, err := setUpJunkCreator(t) + if err != nil { + return + } + for range [30]struct{}{} { + t.Run("", func(t *testing.T) { + if got := jc.randomPacketSize(); jc.device.aSecCfg.junkPacketMinSize > got || + got > jc.device.aSecCfg.junkPacketMaxSize { + t.Errorf( + "junkCreator.randomPacketSize() = %v, not between range [%v,%v]", + got, + jc.device.aSecCfg.junkPacketMinSize, + jc.device.aSecCfg.junkPacketMaxSize, + ) + } + }) + } +} + +func Test_junkCreator_appendJunk(t *testing.T) { + jc, err := setUpJunkCreator(t) + if err != nil { + return + } + t.Run("", func(t *testing.T) { + s := "apple" + buffer := bytes.NewBuffer([]byte(s)) + err := jc.appendJunk(buffer, 30) + if err != nil && + buffer.Len() != len(s)+30 { + t.Errorf("appendWithJunk() size don't match") + } + read := make([]byte, 50) + buffer.Read(read) + fmt.Println(string(read)) + }) +} diff --git a/device/send.go b/device/send.go index 1b4406d..5c54d4d 100644 --- a/device/send.go +++ b/device/send.go @@ -9,7 +9,6 @@ import ( "bytes" "encoding/binary" "errors" - "math/rand" "net" "os" "sync" @@ -121,7 +120,11 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { msg, err := peer.device.CreateMessageInitiation(peer) if err != nil { - peer.device.log.Errorf("%v - Failed to create initiation message: %v", peer, err) + peer.device.log.Errorf( + "%v - Failed to create initiation message: %v", + peer, + err, + ) return err } var sendBuffer [][]byte @@ -129,7 +132,7 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { var junkedHeader []byte if peer.device.isAdvancedSecurityOn() { peer.device.aSecMux.RLock() - junks, err := peer.createJunkPackets() + junks, err := peer.device.junkCreator.createJunkPackets(peer) peer.device.aSecMux.RUnlock() if err != nil { @@ -141,7 +144,11 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { err = peer.SendBuffers(junks) if err != nil { - peer.device.log.Errorf("%v - Failed to send junk packets: %v", peer, err) + peer.device.log.Errorf( + "%v - Failed to send junk packets: %v", + peer, + err, + ) return err } } @@ -150,7 +157,7 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { if peer.device.aSecCfg.initPacketJunkSize != 0 { buf := make([]byte, 0, peer.device.aSecCfg.initPacketJunkSize) writer := bytes.NewBuffer(buf[:0]) - err = appendJunk(writer, peer.device.aSecCfg.initPacketJunkSize) + err = peer.device.junkCreator.appendJunk(writer, peer.device.aSecCfg.initPacketJunkSize) if err != nil { peer.device.log.Errorf("%v - %v", peer, err) peer.device.aSecMux.RUnlock() @@ -175,7 +182,11 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { err = peer.SendBuffers(sendBuffer) if err != nil { - peer.device.log.Errorf("%v - Failed to send handshake initiation: %v", peer, err) + peer.device.log.Errorf( + "%v - Failed to send handshake initiation: %v", + peer, + err, + ) } peer.timersHandshakeInitiated() @@ -191,7 +202,11 @@ func (peer *Peer) SendHandshakeResponse() error { response, err := peer.device.CreateMessageResponse(peer) if err != nil { - peer.device.log.Errorf("%v - Failed to create response message: %v", peer, err) + peer.device.log.Errorf( + "%v - Failed to create response message: %v", + peer, + err, + ) return err } var junkedHeader []byte @@ -200,7 +215,7 @@ func (peer *Peer) SendHandshakeResponse() error { if peer.device.aSecCfg.responsePacketJunkSize != 0 { buf := make([]byte, 0, peer.device.aSecCfg.responsePacketJunkSize) writer := bytes.NewBuffer(buf[:0]) - err = appendJunk(writer, peer.device.aSecCfg.responsePacketJunkSize) + err = peer.device.junkCreator.appendJunk(writer, peer.device.aSecCfg.responsePacketJunkSize) if err != nil { peer.device.aSecMux.RUnlock() peer.device.log.Errorf("%v - %v", peer, err) @@ -231,7 +246,11 @@ func (peer *Peer) SendHandshakeResponse() error { // TODO: allocation could be avoided err = peer.SendBuffers([][]byte{junkedHeader}) if err != nil { - peer.device.log.Errorf("%v - Failed to send handshake response: %v", peer, err) + peer.device.log.Errorf( + "%v - Failed to send handshake response: %v", + peer, + err, + ) } return err } @@ -239,7 +258,10 @@ func (peer *Peer) SendHandshakeResponse() error { func (device *Device) SendHandshakeCookie( initiatingElem *QueueHandshakeElement, ) error { - device.log.Verbosef("Sending cookie response for denied handshake message for %v", initiatingElem.endpoint.DstToString()) + device.log.Verbosef( + "Sending cookie response for denied handshake message for %v", + initiatingElem.endpoint.DstToString(), + ) sender := binary.LittleEndian.Uint32(initiatingElem.packet[4:8]) reply, err := device.cookieChecker.CreateReply( @@ -266,7 +288,8 @@ func (peer *Peer) keepKeyFreshSending() { return } nonce := keypair.sendNonce.Load() - if nonce > RekeyAfterMessages || (keypair.isInitiator && time.Since(keypair.created) > RekeyAfterTime) { + if nonce > RekeyAfterMessages || + (keypair.isInitiator && time.Since(keypair.created) > RekeyAfterTime) { peer.SendHandshakeInitiation(false) } } @@ -369,12 +392,18 @@ func (device *Device) RoutineReadFromTUN() { // TODO: record stat for this // This will happen if MSS is surprisingly small (< 576) // coincident with reasonably high throughput. - device.log.Verbosef("Dropped some packets from multi-segment read: %v", readErr) + device.log.Verbosef( + "Dropped some packets from multi-segment read: %v", + readErr, + ) continue } if !device.isClosed() { if !errors.Is(readErr, os.ErrClosed) { - device.log.Errorf("Failed to read packet from TUN device: %v", readErr) + device.log.Errorf( + "Failed to read packet from TUN device: %v", + readErr, + ) } go device.Close() } @@ -409,7 +438,8 @@ top: } keypair := peer.keypairs.Current() - if keypair == nil || keypair.sendNonce.Load() >= RejectAfterMessages || time.Since(keypair.created) >= RejectAfterTime { + if keypair == nil || keypair.sendNonce.Load() >= RejectAfterMessages || + time.Since(keypair.created) >= RejectAfterTime { peer.SendHandshakeInitiation(false) return } @@ -427,7 +457,10 @@ top: if elemsContainerOOO == nil { elemsContainerOOO = peer.device.GetOutboundElementsContainer() } - elemsContainerOOO.elems = append(elemsContainerOOO.elems, elem) + elemsContainerOOO.elems = append( + elemsContainerOOO.elems, + elem, + ) continue } else { elemsContainer.elems[i] = elem @@ -440,7 +473,9 @@ top: elemsContainer.elems = elemsContainer.elems[:i] if elemsContainerOOO != nil { - peer.StagePackets(elemsContainerOOO) // XXX: Out of order, but we can't front-load go chans + peer.StagePackets( + elemsContainerOOO, + ) // XXX: Out of order, but we can't front-load go chans } if len(elemsContainer.elems) == 0 { @@ -469,31 +504,6 @@ top: } } -func (peer *Peer) createJunkPackets() ([][]byte, error) { - if peer.device.aSecCfg.junkPacketCount == 0 { - return nil, nil - } - - junks := make([][]byte, 0, peer.device.aSecCfg.junkPacketCount) - for i := 0; i < peer.device.aSecCfg.junkPacketCount; i++ { - packetSize := rand.Intn( - peer.device.aSecCfg.junkPacketMaxSize-peer.device.aSecCfg.junkPacketMinSize, - ) + peer.device.aSecCfg.junkPacketMinSize - - junk, err := randomJunkWithSize(packetSize) - if err != nil { - peer.device.log.Errorf( - "%v - Failed to create junk packet: %v", - peer, - err, - ) - return nil, err - } - junks = append(junks, junk) - } - return junks, nil -} - func (peer *Peer) FlushStagedPackets() { for { select { @@ -546,11 +556,17 @@ func (device *Device) RoutineEncryption(id int) { fieldNonce := header[8:16] binary.LittleEndian.PutUint32(fieldType, MessageTransportType) - binary.LittleEndian.PutUint32(fieldReceiver, elem.keypair.remoteIndex) + binary.LittleEndian.PutUint32( + fieldReceiver, + elem.keypair.remoteIndex, + ) binary.LittleEndian.PutUint64(fieldNonce, elem.nonce) // pad content to multiple of 16 - paddingSize := calculatePaddingSize(len(elem.packet), int(device.tun.mtu.Load())) + paddingSize := calculatePaddingSize( + len(elem.packet), + int(device.tun.mtu.Load()), + ) elem.packet = append(elem.packet, paddingZeros[:paddingSize]...) // encrypt content and release to consumer @@ -570,7 +586,10 @@ func (device *Device) RoutineEncryption(id int) { func (peer *Peer) RoutineSequentialSender(maxBatchSize int) { device := peer.device defer func() { - defer device.log.Verbosef("%v - Routine: sequential sender - stopped", peer) + defer device.log.Verbosef( + "%v - Routine: sequential sender - stopped", + peer, + ) peer.stopping.Done() }() device.log.Verbosef("%v - Routine: sequential sender - started", peer) diff --git a/main.go b/main.go index 5a3dfef..77a379a 100644 --- a/main.go +++ b/main.go @@ -59,7 +59,12 @@ func warning() { func main() { if len(os.Args) == 2 && os.Args[1] == "--version" { - fmt.Printf("amneziawg-go %s\n\nUserspace AmneziaWG daemon for %s-%s.\nInformation available at https://amnezia.org\n", Version, runtime.GOOS, runtime.GOARCH) + fmt.Printf( + "amneziawg-go %s\n\nUserspace AmneziaWG daemon for %s-%s.\nInformation available at https://amnezia.org\n", + Version, + runtime.GOOS, + runtime.GOARCH, + ) return }