From 71be0eb3a6547f172d17ce8b831b89e48052dc27 Mon Sep 17 00:00:00 2001
From: Mark Puha
Date: Tue, 18 Mar 2025 08:34:23 +0100
Subject: [PATCH] faster and more secure junk creation
---
Dockerfile | 2 +-
device/device.go | 2 +
device/device_test.go | 6 +-
device/junk_creator.go | 69 ++++++++++++++++++++
device/junk_creator_test.go | 124 ++++++++++++++++++++++++++++++++++++
device/send.go | 32 +---------
device/util.go | 25 --------
device/util_test.go | 27 --------
go.mod | 8 +--
go.sum | 12 ++--
10 files changed, 212 insertions(+), 95 deletions(-)
create mode 100644 device/junk_creator.go
create mode 100644 device/junk_creator_test.go
delete mode 100644 device/util.go
delete mode 100644 device/util_test.go
diff --git a/Dockerfile b/Dockerfile
index 73016f7..12159be 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,4 +1,4 @@
-FROM golang:1.23.6 as awg
+FROM golang:1.24 as awg
COPY . /awg
WORKDIR /awg
RUN go mod download && \
diff --git a/device/device.go b/device/device.go
index 80e3793..1be15d0 100644
--- a/device/device.go
+++ b/device/device.go
@@ -95,6 +95,7 @@ type Device struct {
isASecOn abool.AtomicBool
aSecMux sync.RWMutex
aSecCfg aSecCfgType
+ junkCreator junkCreator
}
type aSecCfgType struct {
@@ -799,6 +800,7 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
}
device.isASecOn.SetTo(isASecOn)
+ device.junkCreator, err = NewJunkCreator(device)
device.aSecMux.Unlock()
return err
diff --git a/device/device_test.go b/device/device_test.go
index e6664a6..d03610f 100644
--- a/device/device_test.go
+++ b/device/device_test.go
@@ -109,7 +109,7 @@ func genASecurityConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) {
"replace_peers", "true",
"jc", "5",
"jmin", "500",
- "jmax", "501",
+ "jmax", "1000",
"s1", "30",
"s2", "40",
"h1", "123456",
@@ -131,7 +131,7 @@ func genASecurityConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) {
"replace_peers", "true",
"jc", "5",
"jmin", "500",
- "jmax", "501",
+ "jmax", "1000",
"s1", "30",
"s2", "40",
"h1", "123456",
@@ -274,7 +274,7 @@ func TestTwoDevicePing(t *testing.T) {
})
}
-func TestTwoDevicePingASecurity(t *testing.T) {
+func TestASecurityTwoDevicePing(t *testing.T) {
goroutineLeakCheck(t)
pair := genTestPair(t, true, true)
t.Run("ping 1.0.0.1", func(t *testing.T) {
diff --git a/device/junk_creator.go b/device/junk_creator.go
new file mode 100644
index 0000000..3a2d3b4
--- /dev/null
+++ b/device/junk_creator.go
@@ -0,0 +1,69 @@
+package device
+
+import (
+ "bytes"
+ crand "crypto/rand"
+ "fmt"
+ v2 "math/rand/v2"
+)
+
+type junkCreator struct {
+ device *Device
+ cha8Rand *v2.ChaCha8
+}
+
+func NewJunkCreator(d *Device) (junkCreator, error) {
+ buf := make([]byte, 32)
+ _, err := crand.Read(buf)
+ if err != nil {
+ return junkCreator{}, err
+ }
+ return junkCreator{device: d, cha8Rand: v2.NewChaCha8([32]byte(buf))}, nil
+}
+
+// Should be called with aSecMux RLocked
+func (jc *junkCreator) createJunkPackets() ([][]byte, error) {
+ if jc.device.aSecCfg.junkPacketCount == 0 {
+ return nil, nil
+ }
+
+ junks := make([][]byte, 0, jc.device.aSecCfg.junkPacketCount)
+ for i := 0; i < jc.device.aSecCfg.junkPacketCount; i++ {
+ packetSize := jc.randomPacketSize()
+ junk, err := jc.randomJunkWithSize(packetSize)
+ if err != nil {
+ return nil, fmt.Errorf("Failed to create junk packet: %v", err)
+ }
+ junks = append(junks, junk)
+ }
+ return junks, nil
+}
+
+// Should be called with aSecMux RLocked
+func (jc *junkCreator) randomPacketSize() int {
+ return int(
+ jc.cha8Rand.Uint64()%uint64(
+ jc.device.aSecCfg.junkPacketMaxSize-jc.device.aSecCfg.junkPacketMinSize,
+ ),
+ ) + jc.device.aSecCfg.junkPacketMinSize
+}
+
+// Should be called with aSecMux RLocked
+func (jc *junkCreator) appendJunk(writer *bytes.Buffer, size int) error {
+ headerJunk, err := jc.randomJunkWithSize(size)
+ if err != nil {
+ return fmt.Errorf("failed to create header junk: %v", err)
+ }
+ _, err = writer.Write(headerJunk)
+ if err != nil {
+ return fmt.Errorf("failed to write header junk: %v", err)
+ }
+ return nil
+}
+
+// Should be called with aSecMux RLocked
+func (jc *junkCreator) randomJunkWithSize(size int) ([]byte, error) {
+ junk := make([]byte, size)
+ _, err := jc.cha8Rand.Read(junk)
+ return junk, err
+}
diff --git a/device/junk_creator_test.go b/device/junk_creator_test.go
new file mode 100644
index 0000000..d3cf2b3
--- /dev/null
+++ b/device/junk_creator_test.go
@@ -0,0 +1,124 @@
+package device
+
+import (
+ "bytes"
+ "fmt"
+ "testing"
+
+ "github.com/amnezia-vpn/amneziawg-go/conn/bindtest"
+ "github.com/amnezia-vpn/amneziawg-go/tun/tuntest"
+)
+
+func setUpJunkCreator(t *testing.T) (junkCreator, error) {
+ cfg, _ := genASecurityConfigs(t)
+ tun := tuntest.NewChannelTUN()
+ binds := bindtest.NewChannelBinds()
+ level := LogLevelVerbose
+ dev := NewDevice(
+ tun.TUN(),
+ binds[0],
+ NewLogger(level, ""),
+ )
+
+ if err := dev.IpcSet(cfg[0]); err != nil {
+ t.Errorf("failed to configure device %v", err)
+ dev.Close()
+ return junkCreator{}, err
+ }
+
+ jc, err := NewJunkCreator(dev)
+
+ if err != nil {
+ t.Errorf("failed to create junk creator %v", err)
+ dev.Close()
+ return junkCreator{}, err
+ }
+
+ return jc, nil
+}
+
+func Test_junkCreator_createJunkPackets(t *testing.T) {
+ jc, err := setUpJunkCreator(t)
+ if err != nil {
+ return
+ }
+ t.Run("", func(t *testing.T) {
+ got, err := jc.createJunkPackets()
+ if err != nil {
+ t.Errorf(
+ "junkCreator.createJunkPackets() = %v; failed",
+ err,
+ )
+ return
+ }
+ seen := make(map[string]bool)
+ for _, junk := range got {
+ key := string(junk)
+ if seen[key] {
+ t.Errorf(
+ "junkCreator.createJunkPackets() = %v, duplicate key: %v",
+ got,
+ junk,
+ )
+ return
+ }
+ seen[key] = true
+ }
+ })
+}
+
+func Test_junkCreator_randomJunkWithSize(t *testing.T) {
+ t.Run("", func(t *testing.T) {
+ jc, err := setUpJunkCreator(t)
+ if err != nil {
+ return
+ }
+ r1, _ := jc.randomJunkWithSize(10)
+ r2, _ := jc.randomJunkWithSize(10)
+ fmt.Printf("%v\n%v\n", r1, r2)
+ if bytes.Equal(r1, r2) {
+ t.Errorf("same junks %v", err)
+ jc.device.Close()
+ return
+ }
+ })
+}
+
+func Test_junkCreator_randomPacketSize(t *testing.T) {
+ jc, err := setUpJunkCreator(t)
+ if err != nil {
+ return
+ }
+ for range [30]struct{}{} {
+ t.Run("", func(t *testing.T) {
+ if got := jc.randomPacketSize(); jc.device.aSecCfg.junkPacketMinSize > got ||
+ got > jc.device.aSecCfg.junkPacketMaxSize {
+ t.Errorf(
+ "junkCreator.randomPacketSize() = %v, not between range [%v,%v]",
+ got,
+ jc.device.aSecCfg.junkPacketMinSize,
+ jc.device.aSecCfg.junkPacketMaxSize,
+ )
+ }
+ })
+ }
+}
+
+func Test_junkCreator_appendJunk(t *testing.T) {
+ jc, err := setUpJunkCreator(t)
+ if err != nil {
+ return
+ }
+ t.Run("", func(t *testing.T) {
+ s := "apple"
+ buffer := bytes.NewBuffer([]byte(s))
+ err := jc.appendJunk(buffer, 30)
+ if err != nil &&
+ buffer.Len() != len(s)+30 {
+ t.Errorf("appendWithJunk() size don't match")
+ }
+ read := make([]byte, 50)
+ buffer.Read(read)
+ fmt.Println(string(read))
+ })
+}
diff --git a/device/send.go b/device/send.go
index 1b4406d..7eca099 100644
--- a/device/send.go
+++ b/device/send.go
@@ -9,7 +9,6 @@ import (
"bytes"
"encoding/binary"
"errors"
- "math/rand"
"net"
"os"
"sync"
@@ -129,7 +128,7 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
var junkedHeader []byte
if peer.device.isAdvancedSecurityOn() {
peer.device.aSecMux.RLock()
- junks, err := peer.createJunkPackets()
+ junks, err := peer.device.junkCreator.createJunkPackets()
peer.device.aSecMux.RUnlock()
if err != nil {
@@ -150,7 +149,7 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
if peer.device.aSecCfg.initPacketJunkSize != 0 {
buf := make([]byte, 0, peer.device.aSecCfg.initPacketJunkSize)
writer := bytes.NewBuffer(buf[:0])
- err = appendJunk(writer, peer.device.aSecCfg.initPacketJunkSize)
+ err = peer.device.junkCreator.appendJunk(writer, peer.device.aSecCfg.initPacketJunkSize)
if err != nil {
peer.device.log.Errorf("%v - %v", peer, err)
peer.device.aSecMux.RUnlock()
@@ -200,7 +199,7 @@ func (peer *Peer) SendHandshakeResponse() error {
if peer.device.aSecCfg.responsePacketJunkSize != 0 {
buf := make([]byte, 0, peer.device.aSecCfg.responsePacketJunkSize)
writer := bytes.NewBuffer(buf[:0])
- err = appendJunk(writer, peer.device.aSecCfg.responsePacketJunkSize)
+ err = peer.device.junkCreator.appendJunk(writer, peer.device.aSecCfg.responsePacketJunkSize)
if err != nil {
peer.device.aSecMux.RUnlock()
peer.device.log.Errorf("%v - %v", peer, err)
@@ -469,31 +468,6 @@ top:
}
}
-func (peer *Peer) createJunkPackets() ([][]byte, error) {
- if peer.device.aSecCfg.junkPacketCount == 0 {
- return nil, nil
- }
-
- junks := make([][]byte, 0, peer.device.aSecCfg.junkPacketCount)
- for i := 0; i < peer.device.aSecCfg.junkPacketCount; i++ {
- packetSize := rand.Intn(
- peer.device.aSecCfg.junkPacketMaxSize-peer.device.aSecCfg.junkPacketMinSize,
- ) + peer.device.aSecCfg.junkPacketMinSize
-
- junk, err := randomJunkWithSize(packetSize)
- if err != nil {
- peer.device.log.Errorf(
- "%v - Failed to create junk packet: %v",
- peer,
- err,
- )
- return nil, err
- }
- junks = append(junks, junk)
- }
- return junks, nil
-}
-
func (peer *Peer) FlushStagedPackets() {
for {
select {
diff --git a/device/util.go b/device/util.go
deleted file mode 100644
index aab8ab7..0000000
--- a/device/util.go
+++ /dev/null
@@ -1,25 +0,0 @@
-package device
-
-import (
- "bytes"
- crand "crypto/rand"
- "fmt"
-)
-
-func appendJunk(writer *bytes.Buffer, size int) error {
- headerJunk, err := randomJunkWithSize(size)
- if err != nil {
- return fmt.Errorf("failed to create header junk: %v", err)
- }
- _, err = writer.Write(headerJunk)
- if err != nil {
- return fmt.Errorf("failed to write header junk: %v", err)
- }
- return nil
-}
-
-func randomJunkWithSize(size int) ([]byte, error) {
- junk := make([]byte, size)
- _, err := crand.Read(junk)
- return junk, err
-}
diff --git a/device/util_test.go b/device/util_test.go
deleted file mode 100644
index c061eef..0000000
--- a/device/util_test.go
+++ /dev/null
@@ -1,27 +0,0 @@
-package device
-
-import (
- "bytes"
- "fmt"
- "testing"
-)
-
-func Test_randomJunktWithSize(t *testing.T) {
- junk, err := randomJunkWithSize(30)
- fmt.Println(string(junk), len(junk), err)
-}
-
-func Test_appendJunk(t *testing.T) {
- t.Run("", func(t *testing.T) {
- s := "apple"
- buffer := bytes.NewBuffer([]byte(s))
- err := appendJunk(buffer, 30)
- if err != nil &&
- buffer.Len() != len(s)+30 {
- t.Errorf("appendWithJunk() size don't match")
- }
- read := make([]byte, 50)
- buffer.Read(read)
- fmt.Println(string(read))
- })
-}
diff --git a/go.mod b/go.mod
index 4575bc8..608969f 100644
--- a/go.mod
+++ b/go.mod
@@ -1,12 +1,12 @@
module github.com/amnezia-vpn/amneziawg-go
-go 1.23.6
+go 1.24
require (
github.com/tevino/abool/v2 v2.1.0
- golang.org/x/crypto v0.32.0
- golang.org/x/net v0.34.0
- golang.org/x/sys v0.29.0
+ golang.org/x/crypto v0.36.0
+ golang.org/x/net v0.37.0
+ golang.org/x/sys v0.31.0
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
gvisor.dev/gvisor v0.0.0-20250130013005-04f9204697c6
)
diff --git a/go.sum b/go.sum
index 10f1f2a..497f949 100644
--- a/go.sum
+++ b/go.sum
@@ -4,14 +4,14 @@ github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/tevino/abool/v2 v2.1.0 h1:7w+Vf9f/5gmKT4m4qkayb33/92M+Um45F2BkHOR+L/c=
github.com/tevino/abool/v2 v2.1.0/go.mod h1:+Lmlqk6bHDWHqN1cbxqhwEAwMPXgc8I1SDEamtseuXY=
-golang.org/x/crypto v0.32.0 h1:euUpcYgM8WcP71gNpTqQCn6rC2t6ULUPiOzfWaXVVfc=
-golang.org/x/crypto v0.32.0/go.mod h1:ZnnJkOaASj8g0AjIduWNlq2NRxL0PlBrbKVyZ6V/Ugc=
+golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
+golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
golang.org/x/mod v0.21.0 h1:vvrHzRwRfVKSiLrG+d4FMl/Qi4ukBCE6kZlTUkDYRT0=
golang.org/x/mod v0.21.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY=
-golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0=
-golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k=
-golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU=
-golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
+golang.org/x/net v0.37.0 h1:1zLorHbz+LYj7MQlSf1+2tPIIgibq2eL5xkrGk6f+2c=
+golang.org/x/net v0.37.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
+golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
+golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/time v0.9.0 h1:EsRrnYcQiGH+5FfbgvV4AP7qEZstoyrHB0DzarOQ4ZY=
golang.org/x/time v0.9.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=