From 339134bc2574f24863ba8487c602e4f609ea68ca Mon Sep 17 00:00:00 2001
From: Mark Puha
Date: Sat, 1 Feb 2025 16:45:30 +0100
Subject: [PATCH 1/3] AWG-2: delete obsolate junk impl
---
device/util.go | 25 -------------------------
device/util_test.go | 27 ---------------------------
2 files changed, 52 deletions(-)
delete mode 100644 device/util.go
delete mode 100644 device/util_test.go
diff --git a/device/util.go b/device/util.go
deleted file mode 100644
index aab8ab7..0000000
--- a/device/util.go
+++ /dev/null
@@ -1,25 +0,0 @@
-package device
-
-import (
- "bytes"
- crand "crypto/rand"
- "fmt"
-)
-
-func appendJunk(writer *bytes.Buffer, size int) error {
- headerJunk, err := randomJunkWithSize(size)
- if err != nil {
- return fmt.Errorf("failed to create header junk: %v", err)
- }
- _, err = writer.Write(headerJunk)
- if err != nil {
- return fmt.Errorf("failed to write header junk: %v", err)
- }
- return nil
-}
-
-func randomJunkWithSize(size int) ([]byte, error) {
- junk := make([]byte, size)
- _, err := crand.Read(junk)
- return junk, err
-}
diff --git a/device/util_test.go b/device/util_test.go
deleted file mode 100644
index c061eef..0000000
--- a/device/util_test.go
+++ /dev/null
@@ -1,27 +0,0 @@
-package device
-
-import (
- "bytes"
- "fmt"
- "testing"
-)
-
-func Test_randomJunktWithSize(t *testing.T) {
- junk, err := randomJunkWithSize(30)
- fmt.Println(string(junk), len(junk), err)
-}
-
-func Test_appendJunk(t *testing.T) {
- t.Run("", func(t *testing.T) {
- s := "apple"
- buffer := bytes.NewBuffer([]byte(s))
- err := appendJunk(buffer, 30)
- if err != nil &&
- buffer.Len() != len(s)+30 {
- t.Errorf("appendWithJunk() size don't match")
- }
- read := make([]byte, 50)
- buffer.Read(read)
- fmt.Println(string(read))
- })
-}
From 9a56c052cc591e1fcf4474ae323e52bb59a9b8fb Mon Sep 17 00:00:00 2001
From: Mark Puha
Date: Sat, 1 Feb 2025 16:45:59 +0100
Subject: [PATCH 2/3] AWG-2 update go version
---
go.mod | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/go.mod b/go.mod
index 115ae88..b03acb0 100644
--- a/go.mod
+++ b/go.mod
@@ -1,6 +1,6 @@
module github.com/amnezia-vpn/amneziawg-go
-go 1.22.3
+go 1.23
require (
github.com/tevino/abool/v2 v2.1.0
From 971144c9fb039722299b2e13632b1b19ad221571 Mon Sep 17 00:00:00 2001
From: Mark Puha
Date: Sat, 1 Feb 2025 16:46:38 +0100
Subject: [PATCH 3/3] AWG-2 create&integrate junk_creator
---
device/device.go | 31 +++++++--
device/device_test.go | 5 +-
device/junk_creator.go | 74 +++++++++++++++++++++
device/junk_creator_test.go | 124 ++++++++++++++++++++++++++++++++++++
device/send.go | 107 ++++++++++++++++++-------------
main.go | 7 +-
6 files changed, 294 insertions(+), 54 deletions(-)
create mode 100644 device/junk_creator.go
create mode 100644 device/junk_creator_test.go
diff --git a/device/device.go b/device/device.go
index 80e3793..1bedce3 100644
--- a/device/device.go
+++ b/device/device.go
@@ -95,6 +95,8 @@ type Device struct {
isASecOn abool.AtomicBool
aSecMux sync.RWMutex
aSecCfg aSecCfgType
+
+ junkCreator junkCreator
}
type aSecCfgType struct {
@@ -161,7 +163,10 @@ func (device *Device) changeState(want deviceState) (err error) {
old := device.deviceState()
if old == deviceStateClosed {
// once closed, always closed
- device.log.Verbosef("Interface closed, ignored requested state %s", want)
+ device.log.Verbosef(
+ "Interface closed, ignored requested state %s",
+ want,
+ )
return nil
}
switch want {
@@ -182,7 +187,11 @@ func (device *Device) changeState(want deviceState) (err error) {
}
}
device.log.Verbosef(
- "Interface state was %s, requested %s, now %s", old, want, device.deviceState())
+ "Interface state was %s, requested %s, now %s",
+ old,
+ want,
+ device.deviceState(),
+ )
return
}
@@ -287,7 +296,9 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
expiredPeers := make([]*Peer, 0, len(device.peers.keyMap))
for _, peer := range device.peers.keyMap {
handshake := &peer.handshake
- handshake.precomputedStaticStatic, _ = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic)
+ handshake.precomputedStaticStatic, _ = device.staticIdentity.privateKey.sharedSecret(
+ handshake.remoteStatic,
+ )
expiredPeers = append(expiredPeers, peer)
}
@@ -433,7 +444,9 @@ func (device *Device) SendKeepalivesToPeersWithCurrentKeypair() {
device.peers.RLock()
for _, peer := range device.peers.keyMap {
peer.keypairs.RLock()
- sendKeepalive := peer.keypairs.current != nil && !peer.keypairs.current.created.Add(RejectAfterTime).Before(time.Now())
+ sendKeepalive := peer.keypairs.current != nil &&
+ !peer.keypairs.current.created.Add(RejectAfterTime).
+ Before(time.Now())
peer.keypairs.RUnlock()
if sendKeepalive {
peer.SendKeepalive()
@@ -539,8 +552,12 @@ func (device *Device) BindUpdate() error {
// start receiving routines
device.net.stopping.Add(len(recvFns))
- device.queue.decryption.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption
- device.queue.handshake.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.handshake
+ device.queue.decryption.wg.Add(
+ len(recvFns),
+ ) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption
+ device.queue.handshake.wg.Add(
+ len(recvFns),
+ ) // each RoutineReceiveIncoming goroutine writes to device.queue.handshake
batchSize := netc.bind.BatchSize()
for _, fn := range recvFns {
go device.RoutineReceiveIncoming(batchSize, fn)
@@ -569,7 +586,6 @@ func (device *Device) resetProtocol() {
}
func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
-
if !tempASecCfg.isSet {
return err
}
@@ -799,6 +815,7 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
}
device.isASecOn.SetTo(isASecOn)
+ device.junkCreator, err = NewJunkCreator(device)
device.aSecMux.Unlock()
return err
diff --git a/device/device_test.go b/device/device_test.go
index e6664a6..e904f26 100644
--- a/device/device_test.go
+++ b/device/device_test.go
@@ -109,7 +109,7 @@ func genASecurityConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) {
"replace_peers", "true",
"jc", "5",
"jmin", "500",
- "jmax", "501",
+ "jmax", "1000",
"s1", "30",
"s2", "40",
"h1", "123456",
@@ -131,7 +131,7 @@ func genASecurityConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) {
"replace_peers", "true",
"jc", "5",
"jmin", "500",
- "jmax", "501",
+ "jmax", "1000",
"s1", "30",
"s2", "40",
"h1", "123456",
@@ -274,6 +274,7 @@ func TestTwoDevicePing(t *testing.T) {
})
}
+// Run test with -race=false to avoid the race for setting the default msgTypes 2 times
func TestTwoDevicePingASecurity(t *testing.T) {
goroutineLeakCheck(t)
pair := genTestPair(t, true, true)
diff --git a/device/junk_creator.go b/device/junk_creator.go
new file mode 100644
index 0000000..85a5bbc
--- /dev/null
+++ b/device/junk_creator.go
@@ -0,0 +1,74 @@
+package device
+
+import (
+ "bytes"
+ crand "crypto/rand"
+ "fmt"
+ v2 "math/rand/v2"
+)
+
+type junkCreator struct {
+ device *Device
+ cha8Rand *v2.ChaCha8
+}
+
+func NewJunkCreator(d *Device) (junkCreator, error) {
+ buf := make([]byte, 32)
+ _, err := crand.Read(buf)
+ if err != nil {
+ return junkCreator{}, err
+ }
+ return junkCreator{device: d, cha8Rand: v2.NewChaCha8([32]byte(buf))}, nil
+}
+
+// Should be called with aSecMux RLocked
+func (jc *junkCreator) createJunkPackets(peer *Peer) ([][]byte, error) {
+ if jc.device.aSecCfg.junkPacketCount == 0 {
+ return nil, nil
+ }
+
+ junks := make([][]byte, 0, jc.device.aSecCfg.junkPacketCount)
+ for i := 0; i < jc.device.aSecCfg.junkPacketCount; i++ {
+ packetSize := jc.randomPacketSize()
+ junk, err := jc.randomJunkWithSize(packetSize)
+ if err != nil {
+ jc.device.log.Errorf(
+ "%v - Failed to create junk packet: %v",
+ peer,
+ err,
+ )
+ return nil, err
+ }
+ junks = append(junks, junk)
+ }
+ return junks, nil
+}
+
+// Should be called with aSecMux RLocked
+func (jc *junkCreator) randomPacketSize() int {
+ return int(
+ jc.cha8Rand.Uint64()%uint64(
+ jc.device.aSecCfg.junkPacketMaxSize-jc.device.aSecCfg.junkPacketMinSize,
+ ),
+ ) + jc.device.aSecCfg.junkPacketMinSize
+}
+
+// Should be called with aSecMux RLocked
+func (jc *junkCreator) appendJunk(writer *bytes.Buffer, size int) error {
+ headerJunk, err := jc.randomJunkWithSize(size)
+ if err != nil {
+ return fmt.Errorf("failed to create header junk: %v", err)
+ }
+ _, err = writer.Write(headerJunk)
+ if err != nil {
+ return fmt.Errorf("failed to write header junk: %v", err)
+ }
+ return nil
+}
+
+// Should be called with aSecMux RLocked
+func (jc *junkCreator) randomJunkWithSize(size int) ([]byte, error) {
+ junk := make([]byte, size)
+ _, err := jc.cha8Rand.Read(junk)
+ return junk, err
+}
diff --git a/device/junk_creator_test.go b/device/junk_creator_test.go
new file mode 100644
index 0000000..6f63360
--- /dev/null
+++ b/device/junk_creator_test.go
@@ -0,0 +1,124 @@
+package device
+
+import (
+ "bytes"
+ "fmt"
+ "testing"
+
+ "github.com/amnezia-vpn/amneziawg-go/conn/bindtest"
+ "github.com/amnezia-vpn/amneziawg-go/tun/tuntest"
+)
+
+func setUpJunkCreator(t *testing.T) (junkCreator, error) {
+ cfg, _ := genASecurityConfigs(t)
+ tun := tuntest.NewChannelTUN()
+ binds := bindtest.NewChannelBinds()
+ level := LogLevelVerbose
+ dev := NewDevice(
+ tun.TUN(),
+ binds[0],
+ NewLogger(level, ""),
+ )
+
+ if err := dev.IpcSet(cfg[0]); err != nil {
+ t.Errorf("failed to configure device %v", err)
+ dev.Close()
+ return junkCreator{}, err
+ }
+
+ jc, err := NewJunkCreator(dev)
+
+ if err != nil {
+ t.Errorf("failed to create junk creator %v", err)
+ dev.Close()
+ return junkCreator{}, err
+ }
+
+ return jc, nil
+}
+
+func Test_junkCreator_createJunkPackets(t *testing.T) {
+ jc, err := setUpJunkCreator(t)
+ if err != nil {
+ return
+ }
+ t.Run("", func(t *testing.T) {
+ got, err := jc.createJunkPackets(nil)
+ if err != nil {
+ t.Errorf(
+ "junkCreator.createJunkPackets() = %v; failed",
+ err,
+ )
+ return
+ }
+ seen := make(map[string]bool)
+ for _, junk := range got {
+ key := string(junk)
+ if seen[key] {
+ t.Errorf(
+ "junkCreator.createJunkPackets() = %v, duplicate key: %v",
+ got,
+ junk,
+ )
+ return
+ }
+ seen[key] = true
+ }
+ })
+}
+
+func Test_junkCreator_randomJunkWithSize(t *testing.T) {
+ t.Run("", func(t *testing.T) {
+ jc, err := setUpJunkCreator(t)
+ if err != nil {
+ return
+ }
+ r1, _ := jc.randomJunkWithSize(10)
+ r2, _ := jc.randomJunkWithSize(10)
+ fmt.Printf("%v\n%v\n", r1, r2)
+ if bytes.Equal(r1, r2) {
+ t.Errorf("same junks %v", err)
+ jc.device.Close()
+ return
+ }
+ })
+}
+
+func Test_junkCreator_randomPacketSize(t *testing.T) {
+ jc, err := setUpJunkCreator(t)
+ if err != nil {
+ return
+ }
+ for range [30]struct{}{} {
+ t.Run("", func(t *testing.T) {
+ if got := jc.randomPacketSize(); jc.device.aSecCfg.junkPacketMinSize > got ||
+ got > jc.device.aSecCfg.junkPacketMaxSize {
+ t.Errorf(
+ "junkCreator.randomPacketSize() = %v, not between range [%v,%v]",
+ got,
+ jc.device.aSecCfg.junkPacketMinSize,
+ jc.device.aSecCfg.junkPacketMaxSize,
+ )
+ }
+ })
+ }
+}
+
+func Test_junkCreator_appendJunk(t *testing.T) {
+ jc, err := setUpJunkCreator(t)
+ if err != nil {
+ return
+ }
+ t.Run("", func(t *testing.T) {
+ s := "apple"
+ buffer := bytes.NewBuffer([]byte(s))
+ err := jc.appendJunk(buffer, 30)
+ if err != nil &&
+ buffer.Len() != len(s)+30 {
+ t.Errorf("appendWithJunk() size don't match")
+ }
+ read := make([]byte, 50)
+ buffer.Read(read)
+ fmt.Println(string(read))
+ })
+}
diff --git a/device/send.go b/device/send.go
index 1b4406d..5c54d4d 100644
--- a/device/send.go
+++ b/device/send.go
@@ -9,7 +9,6 @@ import (
"bytes"
"encoding/binary"
"errors"
- "math/rand"
"net"
"os"
"sync"
@@ -121,7 +120,11 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
msg, err := peer.device.CreateMessageInitiation(peer)
if err != nil {
- peer.device.log.Errorf("%v - Failed to create initiation message: %v", peer, err)
+ peer.device.log.Errorf(
+ "%v - Failed to create initiation message: %v",
+ peer,
+ err,
+ )
return err
}
var sendBuffer [][]byte
@@ -129,7 +132,7 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
var junkedHeader []byte
if peer.device.isAdvancedSecurityOn() {
peer.device.aSecMux.RLock()
- junks, err := peer.createJunkPackets()
+ junks, err := peer.device.junkCreator.createJunkPackets(peer)
peer.device.aSecMux.RUnlock()
if err != nil {
@@ -141,7 +144,11 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
err = peer.SendBuffers(junks)
if err != nil {
- peer.device.log.Errorf("%v - Failed to send junk packets: %v", peer, err)
+ peer.device.log.Errorf(
+ "%v - Failed to send junk packets: %v",
+ peer,
+ err,
+ )
return err
}
}
@@ -150,7 +157,7 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
if peer.device.aSecCfg.initPacketJunkSize != 0 {
buf := make([]byte, 0, peer.device.aSecCfg.initPacketJunkSize)
writer := bytes.NewBuffer(buf[:0])
- err = appendJunk(writer, peer.device.aSecCfg.initPacketJunkSize)
+ err = peer.device.junkCreator.appendJunk(writer, peer.device.aSecCfg.initPacketJunkSize)
if err != nil {
peer.device.log.Errorf("%v - %v", peer, err)
peer.device.aSecMux.RUnlock()
@@ -175,7 +182,11 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
err = peer.SendBuffers(sendBuffer)
if err != nil {
- peer.device.log.Errorf("%v - Failed to send handshake initiation: %v", peer, err)
+ peer.device.log.Errorf(
+ "%v - Failed to send handshake initiation: %v",
+ peer,
+ err,
+ )
}
peer.timersHandshakeInitiated()
@@ -191,7 +202,11 @@ func (peer *Peer) SendHandshakeResponse() error {
response, err := peer.device.CreateMessageResponse(peer)
if err != nil {
- peer.device.log.Errorf("%v - Failed to create response message: %v", peer, err)
+ peer.device.log.Errorf(
+ "%v - Failed to create response message: %v",
+ peer,
+ err,
+ )
return err
}
var junkedHeader []byte
@@ -200,7 +215,7 @@ func (peer *Peer) SendHandshakeResponse() error {
if peer.device.aSecCfg.responsePacketJunkSize != 0 {
buf := make([]byte, 0, peer.device.aSecCfg.responsePacketJunkSize)
writer := bytes.NewBuffer(buf[:0])
- err = appendJunk(writer, peer.device.aSecCfg.responsePacketJunkSize)
+ err = peer.device.junkCreator.appendJunk(writer, peer.device.aSecCfg.responsePacketJunkSize)
if err != nil {
peer.device.aSecMux.RUnlock()
peer.device.log.Errorf("%v - %v", peer, err)
@@ -231,7 +246,11 @@ func (peer *Peer) SendHandshakeResponse() error {
// TODO: allocation could be avoided
err = peer.SendBuffers([][]byte{junkedHeader})
if err != nil {
- peer.device.log.Errorf("%v - Failed to send handshake response: %v", peer, err)
+ peer.device.log.Errorf(
+ "%v - Failed to send handshake response: %v",
+ peer,
+ err,
+ )
}
return err
}
@@ -239,7 +258,10 @@ func (peer *Peer) SendHandshakeResponse() error {
func (device *Device) SendHandshakeCookie(
initiatingElem *QueueHandshakeElement,
) error {
- device.log.Verbosef("Sending cookie response for denied handshake message for %v", initiatingElem.endpoint.DstToString())
+ device.log.Verbosef(
+ "Sending cookie response for denied handshake message for %v",
+ initiatingElem.endpoint.DstToString(),
+ )
sender := binary.LittleEndian.Uint32(initiatingElem.packet[4:8])
reply, err := device.cookieChecker.CreateReply(
@@ -266,7 +288,8 @@ func (peer *Peer) keepKeyFreshSending() {
return
}
nonce := keypair.sendNonce.Load()
- if nonce > RekeyAfterMessages || (keypair.isInitiator && time.Since(keypair.created) > RekeyAfterTime) {
+ if nonce > RekeyAfterMessages ||
+ (keypair.isInitiator && time.Since(keypair.created) > RekeyAfterTime) {
peer.SendHandshakeInitiation(false)
}
}
@@ -369,12 +392,18 @@ func (device *Device) RoutineReadFromTUN() {
// TODO: record stat for this
// This will happen if MSS is surprisingly small (< 576)
// coincident with reasonably high throughput.
- device.log.Verbosef("Dropped some packets from multi-segment read: %v", readErr)
+ device.log.Verbosef(
+ "Dropped some packets from multi-segment read: %v",
+ readErr,
+ )
continue
}
if !device.isClosed() {
if !errors.Is(readErr, os.ErrClosed) {
- device.log.Errorf("Failed to read packet from TUN device: %v", readErr)
+ device.log.Errorf(
+ "Failed to read packet from TUN device: %v",
+ readErr,
+ )
}
go device.Close()
}
@@ -409,7 +438,8 @@ top:
}
keypair := peer.keypairs.Current()
- if keypair == nil || keypair.sendNonce.Load() >= RejectAfterMessages || time.Since(keypair.created) >= RejectAfterTime {
+ if keypair == nil || keypair.sendNonce.Load() >= RejectAfterMessages ||
+ time.Since(keypair.created) >= RejectAfterTime {
peer.SendHandshakeInitiation(false)
return
}
@@ -427,7 +457,10 @@ top:
if elemsContainerOOO == nil {
elemsContainerOOO = peer.device.GetOutboundElementsContainer()
}
- elemsContainerOOO.elems = append(elemsContainerOOO.elems, elem)
+ elemsContainerOOO.elems = append(
+ elemsContainerOOO.elems,
+ elem,
+ )
continue
} else {
elemsContainer.elems[i] = elem
@@ -440,7 +473,9 @@ top:
elemsContainer.elems = elemsContainer.elems[:i]
if elemsContainerOOO != nil {
- peer.StagePackets(elemsContainerOOO) // XXX: Out of order, but we can't front-load go chans
+ peer.StagePackets(
+ elemsContainerOOO,
+ ) // XXX: Out of order, but we can't front-load go chans
}
if len(elemsContainer.elems) == 0 {
@@ -469,31 +504,6 @@ top:
}
}
-func (peer *Peer) createJunkPackets() ([][]byte, error) {
- if peer.device.aSecCfg.junkPacketCount == 0 {
- return nil, nil
- }
-
- junks := make([][]byte, 0, peer.device.aSecCfg.junkPacketCount)
- for i := 0; i < peer.device.aSecCfg.junkPacketCount; i++ {
- packetSize := rand.Intn(
- peer.device.aSecCfg.junkPacketMaxSize-peer.device.aSecCfg.junkPacketMinSize,
- ) + peer.device.aSecCfg.junkPacketMinSize
-
- junk, err := randomJunkWithSize(packetSize)
- if err != nil {
- peer.device.log.Errorf(
- "%v - Failed to create junk packet: %v",
- peer,
- err,
- )
- return nil, err
- }
- junks = append(junks, junk)
- }
- return junks, nil
-}
-
func (peer *Peer) FlushStagedPackets() {
for {
select {
@@ -546,11 +556,17 @@ func (device *Device) RoutineEncryption(id int) {
fieldNonce := header[8:16]
binary.LittleEndian.PutUint32(fieldType, MessageTransportType)
- binary.LittleEndian.PutUint32(fieldReceiver, elem.keypair.remoteIndex)
+ binary.LittleEndian.PutUint32(
+ fieldReceiver,
+ elem.keypair.remoteIndex,
+ )
binary.LittleEndian.PutUint64(fieldNonce, elem.nonce)
// pad content to multiple of 16
- paddingSize := calculatePaddingSize(len(elem.packet), int(device.tun.mtu.Load()))
+ paddingSize := calculatePaddingSize(
+ len(elem.packet),
+ int(device.tun.mtu.Load()),
+ )
elem.packet = append(elem.packet, paddingZeros[:paddingSize]...)
// encrypt content and release to consumer
@@ -570,7 +586,10 @@ func (device *Device) RoutineEncryption(id int) {
func (peer *Peer) RoutineSequentialSender(maxBatchSize int) {
device := peer.device
defer func() {
- defer device.log.Verbosef("%v - Routine: sequential sender - stopped", peer)
+ defer device.log.Verbosef(
+ "%v - Routine: sequential sender - stopped",
+ peer,
+ )
peer.stopping.Done()
}()
device.log.Verbosef("%v - Routine: sequential sender - started", peer)
diff --git a/main.go b/main.go
index 5a3dfef..77a379a 100644
--- a/main.go
+++ b/main.go
@@ -59,7 +59,12 @@ func warning() {
func main() {
if len(os.Args) == 2 && os.Args[1] == "--version" {
- fmt.Printf("amneziawg-go %s\n\nUserspace AmneziaWG daemon for %s-%s.\nInformation available at https://amnezia.org\n", Version, runtime.GOOS, runtime.GOARCH)
+ fmt.Printf(
+ "amneziawg-go %s\n\nUserspace AmneziaWG daemon for %s-%s.\nInformation available at https://amnezia.org\n",
+ Version,
+ runtime.GOOS,
+ runtime.GOARCH,
+ )
return
}