AWG-2 create&integrate junk_creator

This commit is contained in:
Mark Puha 2025-02-01 16:46:38 +01:00
parent 9a56c052cc
commit 971144c9fb
6 changed files with 294 additions and 54 deletions

View file

@ -95,6 +95,8 @@ type Device struct {
isASecOn abool.AtomicBool isASecOn abool.AtomicBool
aSecMux sync.RWMutex aSecMux sync.RWMutex
aSecCfg aSecCfgType aSecCfg aSecCfgType
junkCreator junkCreator
} }
type aSecCfgType struct { type aSecCfgType struct {
@ -161,7 +163,10 @@ func (device *Device) changeState(want deviceState) (err error) {
old := device.deviceState() old := device.deviceState()
if old == deviceStateClosed { if old == deviceStateClosed {
// once closed, always closed // once closed, always closed
device.log.Verbosef("Interface closed, ignored requested state %s", want) device.log.Verbosef(
"Interface closed, ignored requested state %s",
want,
)
return nil return nil
} }
switch want { switch want {
@ -182,7 +187,11 @@ func (device *Device) changeState(want deviceState) (err error) {
} }
} }
device.log.Verbosef( device.log.Verbosef(
"Interface state was %s, requested %s, now %s", old, want, device.deviceState()) "Interface state was %s, requested %s, now %s",
old,
want,
device.deviceState(),
)
return return
} }
@ -287,7 +296,9 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
expiredPeers := make([]*Peer, 0, len(device.peers.keyMap)) expiredPeers := make([]*Peer, 0, len(device.peers.keyMap))
for _, peer := range device.peers.keyMap { for _, peer := range device.peers.keyMap {
handshake := &peer.handshake handshake := &peer.handshake
handshake.precomputedStaticStatic, _ = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic) handshake.precomputedStaticStatic, _ = device.staticIdentity.privateKey.sharedSecret(
handshake.remoteStatic,
)
expiredPeers = append(expiredPeers, peer) expiredPeers = append(expiredPeers, peer)
} }
@ -433,7 +444,9 @@ func (device *Device) SendKeepalivesToPeersWithCurrentKeypair() {
device.peers.RLock() device.peers.RLock()
for _, peer := range device.peers.keyMap { for _, peer := range device.peers.keyMap {
peer.keypairs.RLock() peer.keypairs.RLock()
sendKeepalive := peer.keypairs.current != nil && !peer.keypairs.current.created.Add(RejectAfterTime).Before(time.Now()) sendKeepalive := peer.keypairs.current != nil &&
!peer.keypairs.current.created.Add(RejectAfterTime).
Before(time.Now())
peer.keypairs.RUnlock() peer.keypairs.RUnlock()
if sendKeepalive { if sendKeepalive {
peer.SendKeepalive() peer.SendKeepalive()
@ -539,8 +552,12 @@ func (device *Device) BindUpdate() error {
// start receiving routines // start receiving routines
device.net.stopping.Add(len(recvFns)) device.net.stopping.Add(len(recvFns))
device.queue.decryption.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption device.queue.decryption.wg.Add(
device.queue.handshake.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.handshake len(recvFns),
) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption
device.queue.handshake.wg.Add(
len(recvFns),
) // each RoutineReceiveIncoming goroutine writes to device.queue.handshake
batchSize := netc.bind.BatchSize() batchSize := netc.bind.BatchSize()
for _, fn := range recvFns { for _, fn := range recvFns {
go device.RoutineReceiveIncoming(batchSize, fn) go device.RoutineReceiveIncoming(batchSize, fn)
@ -569,7 +586,6 @@ func (device *Device) resetProtocol() {
} }
func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) { func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
if !tempASecCfg.isSet { if !tempASecCfg.isSet {
return err return err
} }
@ -799,6 +815,7 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
} }
device.isASecOn.SetTo(isASecOn) device.isASecOn.SetTo(isASecOn)
device.junkCreator, err = NewJunkCreator(device)
device.aSecMux.Unlock() device.aSecMux.Unlock()
return err return err

View file

@ -109,7 +109,7 @@ func genASecurityConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) {
"replace_peers", "true", "replace_peers", "true",
"jc", "5", "jc", "5",
"jmin", "500", "jmin", "500",
"jmax", "501", "jmax", "1000",
"s1", "30", "s1", "30",
"s2", "40", "s2", "40",
"h1", "123456", "h1", "123456",
@ -131,7 +131,7 @@ func genASecurityConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) {
"replace_peers", "true", "replace_peers", "true",
"jc", "5", "jc", "5",
"jmin", "500", "jmin", "500",
"jmax", "501", "jmax", "1000",
"s1", "30", "s1", "30",
"s2", "40", "s2", "40",
"h1", "123456", "h1", "123456",
@ -274,6 +274,7 @@ func TestTwoDevicePing(t *testing.T) {
}) })
} }
// Run test with -race=false to avoid the race for setting the default msgTypes 2 times
func TestTwoDevicePingASecurity(t *testing.T) { func TestTwoDevicePingASecurity(t *testing.T) {
goroutineLeakCheck(t) goroutineLeakCheck(t)
pair := genTestPair(t, true, true) pair := genTestPair(t, true, true)

74
device/junk_creator.go Normal file
View file

@ -0,0 +1,74 @@
package device
import (
"bytes"
crand "crypto/rand"
"fmt"
v2 "math/rand/v2"
)
type junkCreator struct {
device *Device
cha8Rand *v2.ChaCha8
}
func NewJunkCreator(d *Device) (junkCreator, error) {
buf := make([]byte, 32)
_, err := crand.Read(buf)
if err != nil {
return junkCreator{}, err
}
return junkCreator{device: d, cha8Rand: v2.NewChaCha8([32]byte(buf))}, nil
}
// Should be called with aSecMux RLocked
func (jc *junkCreator) createJunkPackets(peer *Peer) ([][]byte, error) {
if jc.device.aSecCfg.junkPacketCount == 0 {
return nil, nil
}
junks := make([][]byte, 0, jc.device.aSecCfg.junkPacketCount)
for i := 0; i < jc.device.aSecCfg.junkPacketCount; i++ {
packetSize := jc.randomPacketSize()
junk, err := jc.randomJunkWithSize(packetSize)
if err != nil {
jc.device.log.Errorf(
"%v - Failed to create junk packet: %v",
peer,
err,
)
return nil, err
}
junks = append(junks, junk)
}
return junks, nil
}
// Should be called with aSecMux RLocked
func (jc *junkCreator) randomPacketSize() int {
return int(
jc.cha8Rand.Uint64()%uint64(
jc.device.aSecCfg.junkPacketMaxSize-jc.device.aSecCfg.junkPacketMinSize,
),
) + jc.device.aSecCfg.junkPacketMinSize
}
// Should be called with aSecMux RLocked
func (jc *junkCreator) appendJunk(writer *bytes.Buffer, size int) error {
headerJunk, err := jc.randomJunkWithSize(size)
if err != nil {
return fmt.Errorf("failed to create header junk: %v", err)
}
_, err = writer.Write(headerJunk)
if err != nil {
return fmt.Errorf("failed to write header junk: %v", err)
}
return nil
}
// Should be called with aSecMux RLocked
func (jc *junkCreator) randomJunkWithSize(size int) ([]byte, error) {
junk := make([]byte, size)
_, err := jc.cha8Rand.Read(junk)
return junk, err
}

124
device/junk_creator_test.go Normal file
View file

@ -0,0 +1,124 @@
package device
import (
"bytes"
"fmt"
"testing"
"github.com/amnezia-vpn/amneziawg-go/conn/bindtest"
"github.com/amnezia-vpn/amneziawg-go/tun/tuntest"
)
func setUpJunkCreator(t *testing.T) (junkCreator, error) {
cfg, _ := genASecurityConfigs(t)
tun := tuntest.NewChannelTUN()
binds := bindtest.NewChannelBinds()
level := LogLevelVerbose
dev := NewDevice(
tun.TUN(),
binds[0],
NewLogger(level, ""),
)
if err := dev.IpcSet(cfg[0]); err != nil {
t.Errorf("failed to configure device %v", err)
dev.Close()
return junkCreator{}, err
}
jc, err := NewJunkCreator(dev)
if err != nil {
t.Errorf("failed to create junk creator %v", err)
dev.Close()
return junkCreator{}, err
}
return jc, nil
}
func Test_junkCreator_createJunkPackets(t *testing.T) {
jc, err := setUpJunkCreator(t)
if err != nil {
return
}
t.Run("", func(t *testing.T) {
got, err := jc.createJunkPackets(nil)
if err != nil {
t.Errorf(
"junkCreator.createJunkPackets() = %v; failed",
err,
)
return
}
seen := make(map[string]bool)
for _, junk := range got {
key := string(junk)
if seen[key] {
t.Errorf(
"junkCreator.createJunkPackets() = %v, duplicate key: %v",
got,
junk,
)
return
}
seen[key] = true
}
})
}
func Test_junkCreator_randomJunkWithSize(t *testing.T) {
t.Run("", func(t *testing.T) {
jc, err := setUpJunkCreator(t)
if err != nil {
return
}
r1, _ := jc.randomJunkWithSize(10)
r2, _ := jc.randomJunkWithSize(10)
fmt.Printf("%v\n%v\n", r1, r2)
if bytes.Equal(r1, r2) {
t.Errorf("same junks %v", err)
jc.device.Close()
return
}
})
}
func Test_junkCreator_randomPacketSize(t *testing.T) {
jc, err := setUpJunkCreator(t)
if err != nil {
return
}
for range [30]struct{}{} {
t.Run("", func(t *testing.T) {
if got := jc.randomPacketSize(); jc.device.aSecCfg.junkPacketMinSize > got ||
got > jc.device.aSecCfg.junkPacketMaxSize {
t.Errorf(
"junkCreator.randomPacketSize() = %v, not between range [%v,%v]",
got,
jc.device.aSecCfg.junkPacketMinSize,
jc.device.aSecCfg.junkPacketMaxSize,
)
}
})
}
}
func Test_junkCreator_appendJunk(t *testing.T) {
jc, err := setUpJunkCreator(t)
if err != nil {
return
}
t.Run("", func(t *testing.T) {
s := "apple"
buffer := bytes.NewBuffer([]byte(s))
err := jc.appendJunk(buffer, 30)
if err != nil &&
buffer.Len() != len(s)+30 {
t.Errorf("appendWithJunk() size don't match")
}
read := make([]byte, 50)
buffer.Read(read)
fmt.Println(string(read))
})
}

View file

@ -9,7 +9,6 @@ import (
"bytes" "bytes"
"encoding/binary" "encoding/binary"
"errors" "errors"
"math/rand"
"net" "net"
"os" "os"
"sync" "sync"
@ -121,7 +120,11 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
msg, err := peer.device.CreateMessageInitiation(peer) msg, err := peer.device.CreateMessageInitiation(peer)
if err != nil { if err != nil {
peer.device.log.Errorf("%v - Failed to create initiation message: %v", peer, err) peer.device.log.Errorf(
"%v - Failed to create initiation message: %v",
peer,
err,
)
return err return err
} }
var sendBuffer [][]byte var sendBuffer [][]byte
@ -129,7 +132,7 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
var junkedHeader []byte var junkedHeader []byte
if peer.device.isAdvancedSecurityOn() { if peer.device.isAdvancedSecurityOn() {
peer.device.aSecMux.RLock() peer.device.aSecMux.RLock()
junks, err := peer.createJunkPackets() junks, err := peer.device.junkCreator.createJunkPackets(peer)
peer.device.aSecMux.RUnlock() peer.device.aSecMux.RUnlock()
if err != nil { if err != nil {
@ -141,7 +144,11 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
err = peer.SendBuffers(junks) err = peer.SendBuffers(junks)
if err != nil { if err != nil {
peer.device.log.Errorf("%v - Failed to send junk packets: %v", peer, err) peer.device.log.Errorf(
"%v - Failed to send junk packets: %v",
peer,
err,
)
return err return err
} }
} }
@ -150,7 +157,7 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
if peer.device.aSecCfg.initPacketJunkSize != 0 { if peer.device.aSecCfg.initPacketJunkSize != 0 {
buf := make([]byte, 0, peer.device.aSecCfg.initPacketJunkSize) buf := make([]byte, 0, peer.device.aSecCfg.initPacketJunkSize)
writer := bytes.NewBuffer(buf[:0]) writer := bytes.NewBuffer(buf[:0])
err = appendJunk(writer, peer.device.aSecCfg.initPacketJunkSize) err = peer.device.junkCreator.appendJunk(writer, peer.device.aSecCfg.initPacketJunkSize)
if err != nil { if err != nil {
peer.device.log.Errorf("%v - %v", peer, err) peer.device.log.Errorf("%v - %v", peer, err)
peer.device.aSecMux.RUnlock() peer.device.aSecMux.RUnlock()
@ -175,7 +182,11 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
err = peer.SendBuffers(sendBuffer) err = peer.SendBuffers(sendBuffer)
if err != nil { if err != nil {
peer.device.log.Errorf("%v - Failed to send handshake initiation: %v", peer, err) peer.device.log.Errorf(
"%v - Failed to send handshake initiation: %v",
peer,
err,
)
} }
peer.timersHandshakeInitiated() peer.timersHandshakeInitiated()
@ -191,7 +202,11 @@ func (peer *Peer) SendHandshakeResponse() error {
response, err := peer.device.CreateMessageResponse(peer) response, err := peer.device.CreateMessageResponse(peer)
if err != nil { if err != nil {
peer.device.log.Errorf("%v - Failed to create response message: %v", peer, err) peer.device.log.Errorf(
"%v - Failed to create response message: %v",
peer,
err,
)
return err return err
} }
var junkedHeader []byte var junkedHeader []byte
@ -200,7 +215,7 @@ func (peer *Peer) SendHandshakeResponse() error {
if peer.device.aSecCfg.responsePacketJunkSize != 0 { if peer.device.aSecCfg.responsePacketJunkSize != 0 {
buf := make([]byte, 0, peer.device.aSecCfg.responsePacketJunkSize) buf := make([]byte, 0, peer.device.aSecCfg.responsePacketJunkSize)
writer := bytes.NewBuffer(buf[:0]) writer := bytes.NewBuffer(buf[:0])
err = appendJunk(writer, peer.device.aSecCfg.responsePacketJunkSize) err = peer.device.junkCreator.appendJunk(writer, peer.device.aSecCfg.responsePacketJunkSize)
if err != nil { if err != nil {
peer.device.aSecMux.RUnlock() peer.device.aSecMux.RUnlock()
peer.device.log.Errorf("%v - %v", peer, err) peer.device.log.Errorf("%v - %v", peer, err)
@ -231,7 +246,11 @@ func (peer *Peer) SendHandshakeResponse() error {
// TODO: allocation could be avoided // TODO: allocation could be avoided
err = peer.SendBuffers([][]byte{junkedHeader}) err = peer.SendBuffers([][]byte{junkedHeader})
if err != nil { if err != nil {
peer.device.log.Errorf("%v - Failed to send handshake response: %v", peer, err) peer.device.log.Errorf(
"%v - Failed to send handshake response: %v",
peer,
err,
)
} }
return err return err
} }
@ -239,7 +258,10 @@ func (peer *Peer) SendHandshakeResponse() error {
func (device *Device) SendHandshakeCookie( func (device *Device) SendHandshakeCookie(
initiatingElem *QueueHandshakeElement, initiatingElem *QueueHandshakeElement,
) error { ) error {
device.log.Verbosef("Sending cookie response for denied handshake message for %v", initiatingElem.endpoint.DstToString()) device.log.Verbosef(
"Sending cookie response for denied handshake message for %v",
initiatingElem.endpoint.DstToString(),
)
sender := binary.LittleEndian.Uint32(initiatingElem.packet[4:8]) sender := binary.LittleEndian.Uint32(initiatingElem.packet[4:8])
reply, err := device.cookieChecker.CreateReply( reply, err := device.cookieChecker.CreateReply(
@ -266,7 +288,8 @@ func (peer *Peer) keepKeyFreshSending() {
return return
} }
nonce := keypair.sendNonce.Load() nonce := keypair.sendNonce.Load()
if nonce > RekeyAfterMessages || (keypair.isInitiator && time.Since(keypair.created) > RekeyAfterTime) { if nonce > RekeyAfterMessages ||
(keypair.isInitiator && time.Since(keypair.created) > RekeyAfterTime) {
peer.SendHandshakeInitiation(false) peer.SendHandshakeInitiation(false)
} }
} }
@ -369,12 +392,18 @@ func (device *Device) RoutineReadFromTUN() {
// TODO: record stat for this // TODO: record stat for this
// This will happen if MSS is surprisingly small (< 576) // This will happen if MSS is surprisingly small (< 576)
// coincident with reasonably high throughput. // coincident with reasonably high throughput.
device.log.Verbosef("Dropped some packets from multi-segment read: %v", readErr) device.log.Verbosef(
"Dropped some packets from multi-segment read: %v",
readErr,
)
continue continue
} }
if !device.isClosed() { if !device.isClosed() {
if !errors.Is(readErr, os.ErrClosed) { if !errors.Is(readErr, os.ErrClosed) {
device.log.Errorf("Failed to read packet from TUN device: %v", readErr) device.log.Errorf(
"Failed to read packet from TUN device: %v",
readErr,
)
} }
go device.Close() go device.Close()
} }
@ -409,7 +438,8 @@ top:
} }
keypair := peer.keypairs.Current() keypair := peer.keypairs.Current()
if keypair == nil || keypair.sendNonce.Load() >= RejectAfterMessages || time.Since(keypair.created) >= RejectAfterTime { if keypair == nil || keypair.sendNonce.Load() >= RejectAfterMessages ||
time.Since(keypair.created) >= RejectAfterTime {
peer.SendHandshakeInitiation(false) peer.SendHandshakeInitiation(false)
return return
} }
@ -427,7 +457,10 @@ top:
if elemsContainerOOO == nil { if elemsContainerOOO == nil {
elemsContainerOOO = peer.device.GetOutboundElementsContainer() elemsContainerOOO = peer.device.GetOutboundElementsContainer()
} }
elemsContainerOOO.elems = append(elemsContainerOOO.elems, elem) elemsContainerOOO.elems = append(
elemsContainerOOO.elems,
elem,
)
continue continue
} else { } else {
elemsContainer.elems[i] = elem elemsContainer.elems[i] = elem
@ -440,7 +473,9 @@ top:
elemsContainer.elems = elemsContainer.elems[:i] elemsContainer.elems = elemsContainer.elems[:i]
if elemsContainerOOO != nil { if elemsContainerOOO != nil {
peer.StagePackets(elemsContainerOOO) // XXX: Out of order, but we can't front-load go chans peer.StagePackets(
elemsContainerOOO,
) // XXX: Out of order, but we can't front-load go chans
} }
if len(elemsContainer.elems) == 0 { if len(elemsContainer.elems) == 0 {
@ -469,31 +504,6 @@ top:
} }
} }
func (peer *Peer) createJunkPackets() ([][]byte, error) {
if peer.device.aSecCfg.junkPacketCount == 0 {
return nil, nil
}
junks := make([][]byte, 0, peer.device.aSecCfg.junkPacketCount)
for i := 0; i < peer.device.aSecCfg.junkPacketCount; i++ {
packetSize := rand.Intn(
peer.device.aSecCfg.junkPacketMaxSize-peer.device.aSecCfg.junkPacketMinSize,
) + peer.device.aSecCfg.junkPacketMinSize
junk, err := randomJunkWithSize(packetSize)
if err != nil {
peer.device.log.Errorf(
"%v - Failed to create junk packet: %v",
peer,
err,
)
return nil, err
}
junks = append(junks, junk)
}
return junks, nil
}
func (peer *Peer) FlushStagedPackets() { func (peer *Peer) FlushStagedPackets() {
for { for {
select { select {
@ -546,11 +556,17 @@ func (device *Device) RoutineEncryption(id int) {
fieldNonce := header[8:16] fieldNonce := header[8:16]
binary.LittleEndian.PutUint32(fieldType, MessageTransportType) binary.LittleEndian.PutUint32(fieldType, MessageTransportType)
binary.LittleEndian.PutUint32(fieldReceiver, elem.keypair.remoteIndex) binary.LittleEndian.PutUint32(
fieldReceiver,
elem.keypair.remoteIndex,
)
binary.LittleEndian.PutUint64(fieldNonce, elem.nonce) binary.LittleEndian.PutUint64(fieldNonce, elem.nonce)
// pad content to multiple of 16 // pad content to multiple of 16
paddingSize := calculatePaddingSize(len(elem.packet), int(device.tun.mtu.Load())) paddingSize := calculatePaddingSize(
len(elem.packet),
int(device.tun.mtu.Load()),
)
elem.packet = append(elem.packet, paddingZeros[:paddingSize]...) elem.packet = append(elem.packet, paddingZeros[:paddingSize]...)
// encrypt content and release to consumer // encrypt content and release to consumer
@ -570,7 +586,10 @@ func (device *Device) RoutineEncryption(id int) {
func (peer *Peer) RoutineSequentialSender(maxBatchSize int) { func (peer *Peer) RoutineSequentialSender(maxBatchSize int) {
device := peer.device device := peer.device
defer func() { defer func() {
defer device.log.Verbosef("%v - Routine: sequential sender - stopped", peer) defer device.log.Verbosef(
"%v - Routine: sequential sender - stopped",
peer,
)
peer.stopping.Done() peer.stopping.Done()
}() }()
device.log.Verbosef("%v - Routine: sequential sender - started", peer) device.log.Verbosef("%v - Routine: sequential sender - started", peer)

View file

@ -59,7 +59,12 @@ func warning() {
func main() { func main() {
if len(os.Args) == 2 && os.Args[1] == "--version" { if len(os.Args) == 2 && os.Args[1] == "--version" {
fmt.Printf("amneziawg-go %s\n\nUserspace AmneziaWG daemon for %s-%s.\nInformation available at https://amnezia.org\n", Version, runtime.GOOS, runtime.GOARCH) fmt.Printf(
"amneziawg-go %s\n\nUserspace AmneziaWG daemon for %s-%s.\nInformation available at https://amnezia.org\n",
Version,
runtime.GOOS,
runtime.GOARCH,
)
return return
} }