From 71be0eb3a6547f172d17ce8b831b89e48052dc27 Mon Sep 17 00:00:00 2001 From: Mark Puha Date: Tue, 18 Mar 2025 08:34:23 +0100 Subject: [PATCH] faster and more secure junk creation --- Dockerfile | 2 +- device/device.go | 2 + device/device_test.go | 6 +- device/junk_creator.go | 69 ++++++++++++++++++++ device/junk_creator_test.go | 124 ++++++++++++++++++++++++++++++++++++ device/send.go | 32 +--------- device/util.go | 25 -------- device/util_test.go | 27 -------- go.mod | 8 +-- go.sum | 12 ++-- 10 files changed, 212 insertions(+), 95 deletions(-) create mode 100644 device/junk_creator.go create mode 100644 device/junk_creator_test.go delete mode 100644 device/util.go delete mode 100644 device/util_test.go diff --git a/Dockerfile b/Dockerfile index 73016f7..12159be 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,4 @@ -FROM golang:1.23.6 as awg +FROM golang:1.24 as awg COPY . /awg WORKDIR /awg RUN go mod download && \ diff --git a/device/device.go b/device/device.go index 80e3793..1be15d0 100644 --- a/device/device.go +++ b/device/device.go @@ -95,6 +95,7 @@ type Device struct { isASecOn abool.AtomicBool aSecMux sync.RWMutex aSecCfg aSecCfgType + junkCreator junkCreator } type aSecCfgType struct { @@ -799,6 +800,7 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) { } device.isASecOn.SetTo(isASecOn) + device.junkCreator, err = NewJunkCreator(device) device.aSecMux.Unlock() return err diff --git a/device/device_test.go b/device/device_test.go index e6664a6..d03610f 100644 --- a/device/device_test.go +++ b/device/device_test.go @@ -109,7 +109,7 @@ func genASecurityConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) { "replace_peers", "true", "jc", "5", "jmin", "500", - "jmax", "501", + "jmax", "1000", "s1", "30", "s2", "40", "h1", "123456", @@ -131,7 +131,7 @@ func genASecurityConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) { "replace_peers", "true", "jc", "5", "jmin", "500", - "jmax", "501", + "jmax", "1000", "s1", "30", "s2", "40", "h1", "123456", @@ -274,7 +274,7 @@ func TestTwoDevicePing(t *testing.T) { }) } -func TestTwoDevicePingASecurity(t *testing.T) { +func TestASecurityTwoDevicePing(t *testing.T) { goroutineLeakCheck(t) pair := genTestPair(t, true, true) t.Run("ping 1.0.0.1", func(t *testing.T) { diff --git a/device/junk_creator.go b/device/junk_creator.go new file mode 100644 index 0000000..3a2d3b4 --- /dev/null +++ b/device/junk_creator.go @@ -0,0 +1,69 @@ +package device + +import ( + "bytes" + crand "crypto/rand" + "fmt" + v2 "math/rand/v2" +) + +type junkCreator struct { + device *Device + cha8Rand *v2.ChaCha8 +} + +func NewJunkCreator(d *Device) (junkCreator, error) { + buf := make([]byte, 32) + _, err := crand.Read(buf) + if err != nil { + return junkCreator{}, err + } + return junkCreator{device: d, cha8Rand: v2.NewChaCha8([32]byte(buf))}, nil +} + +// Should be called with aSecMux RLocked +func (jc *junkCreator) createJunkPackets() ([][]byte, error) { + if jc.device.aSecCfg.junkPacketCount == 0 { + return nil, nil + } + + junks := make([][]byte, 0, jc.device.aSecCfg.junkPacketCount) + for i := 0; i < jc.device.aSecCfg.junkPacketCount; i++ { + packetSize := jc.randomPacketSize() + junk, err := jc.randomJunkWithSize(packetSize) + if err != nil { + return nil, fmt.Errorf("Failed to create junk packet: %v", err) + } + junks = append(junks, junk) + } + return junks, nil +} + +// Should be called with aSecMux RLocked +func (jc *junkCreator) randomPacketSize() int { + return int( + jc.cha8Rand.Uint64()%uint64( + jc.device.aSecCfg.junkPacketMaxSize-jc.device.aSecCfg.junkPacketMinSize, + ), + ) + jc.device.aSecCfg.junkPacketMinSize +} + +// Should be called with aSecMux RLocked +func (jc *junkCreator) appendJunk(writer *bytes.Buffer, size int) error { + headerJunk, err := jc.randomJunkWithSize(size) + if err != nil { + return fmt.Errorf("failed to create header junk: %v", err) + } + _, err = writer.Write(headerJunk) + if err != nil { + return fmt.Errorf("failed to write header junk: %v", err) + } + return nil +} + +// Should be called with aSecMux RLocked +func (jc *junkCreator) randomJunkWithSize(size int) ([]byte, error) { + junk := make([]byte, size) + _, err := jc.cha8Rand.Read(junk) + return junk, err +} diff --git a/device/junk_creator_test.go b/device/junk_creator_test.go new file mode 100644 index 0000000..d3cf2b3 --- /dev/null +++ b/device/junk_creator_test.go @@ -0,0 +1,124 @@ +package device + +import ( + "bytes" + "fmt" + "testing" + + "github.com/amnezia-vpn/amneziawg-go/conn/bindtest" + "github.com/amnezia-vpn/amneziawg-go/tun/tuntest" +) + +func setUpJunkCreator(t *testing.T) (junkCreator, error) { + cfg, _ := genASecurityConfigs(t) + tun := tuntest.NewChannelTUN() + binds := bindtest.NewChannelBinds() + level := LogLevelVerbose + dev := NewDevice( + tun.TUN(), + binds[0], + NewLogger(level, ""), + ) + + if err := dev.IpcSet(cfg[0]); err != nil { + t.Errorf("failed to configure device %v", err) + dev.Close() + return junkCreator{}, err + } + + jc, err := NewJunkCreator(dev) + + if err != nil { + t.Errorf("failed to create junk creator %v", err) + dev.Close() + return junkCreator{}, err + } + + return jc, nil +} + +func Test_junkCreator_createJunkPackets(t *testing.T) { + jc, err := setUpJunkCreator(t) + if err != nil { + return + } + t.Run("", func(t *testing.T) { + got, err := jc.createJunkPackets() + if err != nil { + t.Errorf( + "junkCreator.createJunkPackets() = %v; failed", + err, + ) + return + } + seen := make(map[string]bool) + for _, junk := range got { + key := string(junk) + if seen[key] { + t.Errorf( + "junkCreator.createJunkPackets() = %v, duplicate key: %v", + got, + junk, + ) + return + } + seen[key] = true + } + }) +} + +func Test_junkCreator_randomJunkWithSize(t *testing.T) { + t.Run("", func(t *testing.T) { + jc, err := setUpJunkCreator(t) + if err != nil { + return + } + r1, _ := jc.randomJunkWithSize(10) + r2, _ := jc.randomJunkWithSize(10) + fmt.Printf("%v\n%v\n", r1, r2) + if bytes.Equal(r1, r2) { + t.Errorf("same junks %v", err) + jc.device.Close() + return + } + }) +} + +func Test_junkCreator_randomPacketSize(t *testing.T) { + jc, err := setUpJunkCreator(t) + if err != nil { + return + } + for range [30]struct{}{} { + t.Run("", func(t *testing.T) { + if got := jc.randomPacketSize(); jc.device.aSecCfg.junkPacketMinSize > got || + got > jc.device.aSecCfg.junkPacketMaxSize { + t.Errorf( + "junkCreator.randomPacketSize() = %v, not between range [%v,%v]", + got, + jc.device.aSecCfg.junkPacketMinSize, + jc.device.aSecCfg.junkPacketMaxSize, + ) + } + }) + } +} + +func Test_junkCreator_appendJunk(t *testing.T) { + jc, err := setUpJunkCreator(t) + if err != nil { + return + } + t.Run("", func(t *testing.T) { + s := "apple" + buffer := bytes.NewBuffer([]byte(s)) + err := jc.appendJunk(buffer, 30) + if err != nil && + buffer.Len() != len(s)+30 { + t.Errorf("appendWithJunk() size don't match") + } + read := make([]byte, 50) + buffer.Read(read) + fmt.Println(string(read)) + }) +} diff --git a/device/send.go b/device/send.go index 1b4406d..7eca099 100644 --- a/device/send.go +++ b/device/send.go @@ -9,7 +9,6 @@ import ( "bytes" "encoding/binary" "errors" - "math/rand" "net" "os" "sync" @@ -129,7 +128,7 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { var junkedHeader []byte if peer.device.isAdvancedSecurityOn() { peer.device.aSecMux.RLock() - junks, err := peer.createJunkPackets() + junks, err := peer.device.junkCreator.createJunkPackets() peer.device.aSecMux.RUnlock() if err != nil { @@ -150,7 +149,7 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { if peer.device.aSecCfg.initPacketJunkSize != 0 { buf := make([]byte, 0, peer.device.aSecCfg.initPacketJunkSize) writer := bytes.NewBuffer(buf[:0]) - err = appendJunk(writer, peer.device.aSecCfg.initPacketJunkSize) + err = peer.device.junkCreator.appendJunk(writer, peer.device.aSecCfg.initPacketJunkSize) if err != nil { peer.device.log.Errorf("%v - %v", peer, err) peer.device.aSecMux.RUnlock() @@ -200,7 +199,7 @@ func (peer *Peer) SendHandshakeResponse() error { if peer.device.aSecCfg.responsePacketJunkSize != 0 { buf := make([]byte, 0, peer.device.aSecCfg.responsePacketJunkSize) writer := bytes.NewBuffer(buf[:0]) - err = appendJunk(writer, peer.device.aSecCfg.responsePacketJunkSize) + err = peer.device.junkCreator.appendJunk(writer, peer.device.aSecCfg.responsePacketJunkSize) if err != nil { peer.device.aSecMux.RUnlock() peer.device.log.Errorf("%v - %v", peer, err) @@ -469,31 +468,6 @@ top: } } -func (peer *Peer) createJunkPackets() ([][]byte, error) { - if peer.device.aSecCfg.junkPacketCount == 0 { - return nil, nil - } - - junks := make([][]byte, 0, peer.device.aSecCfg.junkPacketCount) - for i := 0; i < peer.device.aSecCfg.junkPacketCount; i++ { - packetSize := rand.Intn( - peer.device.aSecCfg.junkPacketMaxSize-peer.device.aSecCfg.junkPacketMinSize, - ) + peer.device.aSecCfg.junkPacketMinSize - - junk, err := randomJunkWithSize(packetSize) - if err != nil { - peer.device.log.Errorf( - "%v - Failed to create junk packet: %v", - peer, - err, - ) - return nil, err - } - junks = append(junks, junk) - } - return junks, nil -} - func (peer *Peer) FlushStagedPackets() { for { select { diff --git a/device/util.go b/device/util.go deleted file mode 100644 index aab8ab7..0000000 --- a/device/util.go +++ /dev/null @@ -1,25 +0,0 @@ -package device - -import ( - "bytes" - crand "crypto/rand" - "fmt" -) - -func appendJunk(writer *bytes.Buffer, size int) error { - headerJunk, err := randomJunkWithSize(size) - if err != nil { - return fmt.Errorf("failed to create header junk: %v", err) - } - _, err = writer.Write(headerJunk) - if err != nil { - return fmt.Errorf("failed to write header junk: %v", err) - } - return nil -} - -func randomJunkWithSize(size int) ([]byte, error) { - junk := make([]byte, size) - _, err := crand.Read(junk) - return junk, err -} diff --git a/device/util_test.go b/device/util_test.go deleted file mode 100644 index c061eef..0000000 --- a/device/util_test.go +++ /dev/null @@ -1,27 +0,0 @@ -package device - -import ( - "bytes" - "fmt" - "testing" -) - -func Test_randomJunktWithSize(t *testing.T) { - junk, err := randomJunkWithSize(30) - fmt.Println(string(junk), len(junk), err) -} - -func Test_appendJunk(t *testing.T) { - t.Run("", func(t *testing.T) { - s := "apple" - buffer := bytes.NewBuffer([]byte(s)) - err := appendJunk(buffer, 30) - if err != nil && - buffer.Len() != len(s)+30 { - t.Errorf("appendWithJunk() size don't match") - } - read := make([]byte, 50) - buffer.Read(read) - fmt.Println(string(read)) - }) -} diff --git a/go.mod b/go.mod index 4575bc8..608969f 100644 --- a/go.mod +++ b/go.mod @@ -1,12 +1,12 @@ module github.com/amnezia-vpn/amneziawg-go -go 1.23.6 +go 1.24 require ( github.com/tevino/abool/v2 v2.1.0 - golang.org/x/crypto v0.32.0 - golang.org/x/net v0.34.0 - golang.org/x/sys v0.29.0 + golang.org/x/crypto v0.36.0 + golang.org/x/net v0.37.0 + golang.org/x/sys v0.31.0 golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 gvisor.dev/gvisor v0.0.0-20250130013005-04f9204697c6 ) diff --git a/go.sum b/go.sum index 10f1f2a..497f949 100644 --- a/go.sum +++ b/go.sum @@ -4,14 +4,14 @@ github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/tevino/abool/v2 v2.1.0 h1:7w+Vf9f/5gmKT4m4qkayb33/92M+Um45F2BkHOR+L/c= github.com/tevino/abool/v2 v2.1.0/go.mod h1:+Lmlqk6bHDWHqN1cbxqhwEAwMPXgc8I1SDEamtseuXY= -golang.org/x/crypto v0.32.0 h1:euUpcYgM8WcP71gNpTqQCn6rC2t6ULUPiOzfWaXVVfc= -golang.org/x/crypto v0.32.0/go.mod h1:ZnnJkOaASj8g0AjIduWNlq2NRxL0PlBrbKVyZ6V/Ugc= +golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34= +golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc= golang.org/x/mod v0.21.0 h1:vvrHzRwRfVKSiLrG+d4FMl/Qi4ukBCE6kZlTUkDYRT0= golang.org/x/mod v0.21.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= -golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0= -golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k= -golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU= -golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/net v0.37.0 h1:1zLorHbz+LYj7MQlSf1+2tPIIgibq2eL5xkrGk6f+2c= +golang.org/x/net v0.37.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= +golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik= +golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/time v0.9.0 h1:EsRrnYcQiGH+5FfbgvV4AP7qEZstoyrHB0DzarOQ4ZY= golang.org/x/time v0.9.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=