Merge pull request #1 from marko1777/wg-2.0

Wg 2.0
This commit is contained in:
Mark Puha 2025-02-01 16:54:07 +01:00 committed by GitHub
commit 65215f0ed3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 295 additions and 107 deletions

View file

@ -95,6 +95,8 @@ type Device struct {
isASecOn abool.AtomicBool
aSecMux sync.RWMutex
aSecCfg aSecCfgType
junkCreator junkCreator
}
type aSecCfgType struct {
@ -161,7 +163,10 @@ func (device *Device) changeState(want deviceState) (err error) {
old := device.deviceState()
if old == deviceStateClosed {
// once closed, always closed
device.log.Verbosef("Interface closed, ignored requested state %s", want)
device.log.Verbosef(
"Interface closed, ignored requested state %s",
want,
)
return nil
}
switch want {
@ -182,7 +187,11 @@ func (device *Device) changeState(want deviceState) (err error) {
}
}
device.log.Verbosef(
"Interface state was %s, requested %s, now %s", old, want, device.deviceState())
"Interface state was %s, requested %s, now %s",
old,
want,
device.deviceState(),
)
return
}
@ -287,7 +296,9 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
expiredPeers := make([]*Peer, 0, len(device.peers.keyMap))
for _, peer := range device.peers.keyMap {
handshake := &peer.handshake
handshake.precomputedStaticStatic, _ = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic)
handshake.precomputedStaticStatic, _ = device.staticIdentity.privateKey.sharedSecret(
handshake.remoteStatic,
)
expiredPeers = append(expiredPeers, peer)
}
@ -433,7 +444,9 @@ func (device *Device) SendKeepalivesToPeersWithCurrentKeypair() {
device.peers.RLock()
for _, peer := range device.peers.keyMap {
peer.keypairs.RLock()
sendKeepalive := peer.keypairs.current != nil && !peer.keypairs.current.created.Add(RejectAfterTime).Before(time.Now())
sendKeepalive := peer.keypairs.current != nil &&
!peer.keypairs.current.created.Add(RejectAfterTime).
Before(time.Now())
peer.keypairs.RUnlock()
if sendKeepalive {
peer.SendKeepalive()
@ -539,8 +552,12 @@ func (device *Device) BindUpdate() error {
// start receiving routines
device.net.stopping.Add(len(recvFns))
device.queue.decryption.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption
device.queue.handshake.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.handshake
device.queue.decryption.wg.Add(
len(recvFns),
) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption
device.queue.handshake.wg.Add(
len(recvFns),
) // each RoutineReceiveIncoming goroutine writes to device.queue.handshake
batchSize := netc.bind.BatchSize()
for _, fn := range recvFns {
go device.RoutineReceiveIncoming(batchSize, fn)
@ -569,7 +586,6 @@ func (device *Device) resetProtocol() {
}
func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
if !tempASecCfg.isSet {
return err
}
@ -799,6 +815,7 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
}
device.isASecOn.SetTo(isASecOn)
device.junkCreator, err = NewJunkCreator(device)
device.aSecMux.Unlock()
return err

View file

@ -109,7 +109,7 @@ func genASecurityConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) {
"replace_peers", "true",
"jc", "5",
"jmin", "500",
"jmax", "501",
"jmax", "1000",
"s1", "30",
"s2", "40",
"h1", "123456",
@ -131,7 +131,7 @@ func genASecurityConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) {
"replace_peers", "true",
"jc", "5",
"jmin", "500",
"jmax", "501",
"jmax", "1000",
"s1", "30",
"s2", "40",
"h1", "123456",
@ -274,6 +274,7 @@ func TestTwoDevicePing(t *testing.T) {
})
}
// Run test with -race=false to avoid the race for setting the default msgTypes 2 times
func TestTwoDevicePingASecurity(t *testing.T) {
goroutineLeakCheck(t)
pair := genTestPair(t, true, true)

74
device/junk_creator.go Normal file
View file

@ -0,0 +1,74 @@
package device
import (
"bytes"
crand "crypto/rand"
"fmt"
v2 "math/rand/v2"
)
type junkCreator struct {
device *Device
cha8Rand *v2.ChaCha8
}
func NewJunkCreator(d *Device) (junkCreator, error) {
buf := make([]byte, 32)
_, err := crand.Read(buf)
if err != nil {
return junkCreator{}, err
}
return junkCreator{device: d, cha8Rand: v2.NewChaCha8([32]byte(buf))}, nil
}
// Should be called with aSecMux RLocked
func (jc *junkCreator) createJunkPackets(peer *Peer) ([][]byte, error) {
if jc.device.aSecCfg.junkPacketCount == 0 {
return nil, nil
}
junks := make([][]byte, 0, jc.device.aSecCfg.junkPacketCount)
for i := 0; i < jc.device.aSecCfg.junkPacketCount; i++ {
packetSize := jc.randomPacketSize()
junk, err := jc.randomJunkWithSize(packetSize)
if err != nil {
jc.device.log.Errorf(
"%v - Failed to create junk packet: %v",
peer,
err,
)
return nil, err
}
junks = append(junks, junk)
}
return junks, nil
}
// Should be called with aSecMux RLocked
func (jc *junkCreator) randomPacketSize() int {
return int(
jc.cha8Rand.Uint64()%uint64(
jc.device.aSecCfg.junkPacketMaxSize-jc.device.aSecCfg.junkPacketMinSize,
),
) + jc.device.aSecCfg.junkPacketMinSize
}
// Should be called with aSecMux RLocked
func (jc *junkCreator) appendJunk(writer *bytes.Buffer, size int) error {
headerJunk, err := jc.randomJunkWithSize(size)
if err != nil {
return fmt.Errorf("failed to create header junk: %v", err)
}
_, err = writer.Write(headerJunk)
if err != nil {
return fmt.Errorf("failed to write header junk: %v", err)
}
return nil
}
// Should be called with aSecMux RLocked
func (jc *junkCreator) randomJunkWithSize(size int) ([]byte, error) {
junk := make([]byte, size)
_, err := jc.cha8Rand.Read(junk)
return junk, err
}

124
device/junk_creator_test.go Normal file
View file

@ -0,0 +1,124 @@
package device
import (
"bytes"
"fmt"
"testing"
"github.com/amnezia-vpn/amneziawg-go/conn/bindtest"
"github.com/amnezia-vpn/amneziawg-go/tun/tuntest"
)
func setUpJunkCreator(t *testing.T) (junkCreator, error) {
cfg, _ := genASecurityConfigs(t)
tun := tuntest.NewChannelTUN()
binds := bindtest.NewChannelBinds()
level := LogLevelVerbose
dev := NewDevice(
tun.TUN(),
binds[0],
NewLogger(level, ""),
)
if err := dev.IpcSet(cfg[0]); err != nil {
t.Errorf("failed to configure device %v", err)
dev.Close()
return junkCreator{}, err
}
jc, err := NewJunkCreator(dev)
if err != nil {
t.Errorf("failed to create junk creator %v", err)
dev.Close()
return junkCreator{}, err
}
return jc, nil
}
func Test_junkCreator_createJunkPackets(t *testing.T) {
jc, err := setUpJunkCreator(t)
if err != nil {
return
}
t.Run("", func(t *testing.T) {
got, err := jc.createJunkPackets(nil)
if err != nil {
t.Errorf(
"junkCreator.createJunkPackets() = %v; failed",
err,
)
return
}
seen := make(map[string]bool)
for _, junk := range got {
key := string(junk)
if seen[key] {
t.Errorf(
"junkCreator.createJunkPackets() = %v, duplicate key: %v",
got,
junk,
)
return
}
seen[key] = true
}
})
}
func Test_junkCreator_randomJunkWithSize(t *testing.T) {
t.Run("", func(t *testing.T) {
jc, err := setUpJunkCreator(t)
if err != nil {
return
}
r1, _ := jc.randomJunkWithSize(10)
r2, _ := jc.randomJunkWithSize(10)
fmt.Printf("%v\n%v\n", r1, r2)
if bytes.Equal(r1, r2) {
t.Errorf("same junks %v", err)
jc.device.Close()
return
}
})
}
func Test_junkCreator_randomPacketSize(t *testing.T) {
jc, err := setUpJunkCreator(t)
if err != nil {
return
}
for range [30]struct{}{} {
t.Run("", func(t *testing.T) {
if got := jc.randomPacketSize(); jc.device.aSecCfg.junkPacketMinSize > got ||
got > jc.device.aSecCfg.junkPacketMaxSize {
t.Errorf(
"junkCreator.randomPacketSize() = %v, not between range [%v,%v]",
got,
jc.device.aSecCfg.junkPacketMinSize,
jc.device.aSecCfg.junkPacketMaxSize,
)
}
})
}
}
func Test_junkCreator_appendJunk(t *testing.T) {
jc, err := setUpJunkCreator(t)
if err != nil {
return
}
t.Run("", func(t *testing.T) {
s := "apple"
buffer := bytes.NewBuffer([]byte(s))
err := jc.appendJunk(buffer, 30)
if err != nil &&
buffer.Len() != len(s)+30 {
t.Errorf("appendWithJunk() size don't match")
}
read := make([]byte, 50)
buffer.Read(read)
fmt.Println(string(read))
})
}

View file

@ -9,7 +9,6 @@ import (
"bytes"
"encoding/binary"
"errors"
"math/rand"
"net"
"os"
"sync"
@ -121,7 +120,11 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
msg, err := peer.device.CreateMessageInitiation(peer)
if err != nil {
peer.device.log.Errorf("%v - Failed to create initiation message: %v", peer, err)
peer.device.log.Errorf(
"%v - Failed to create initiation message: %v",
peer,
err,
)
return err
}
var sendBuffer [][]byte
@ -129,7 +132,7 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
var junkedHeader []byte
if peer.device.isAdvancedSecurityOn() {
peer.device.aSecMux.RLock()
junks, err := peer.createJunkPackets()
junks, err := peer.device.junkCreator.createJunkPackets(peer)
peer.device.aSecMux.RUnlock()
if err != nil {
@ -141,7 +144,11 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
err = peer.SendBuffers(junks)
if err != nil {
peer.device.log.Errorf("%v - Failed to send junk packets: %v", peer, err)
peer.device.log.Errorf(
"%v - Failed to send junk packets: %v",
peer,
err,
)
return err
}
}
@ -150,7 +157,7 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
if peer.device.aSecCfg.initPacketJunkSize != 0 {
buf := make([]byte, 0, peer.device.aSecCfg.initPacketJunkSize)
writer := bytes.NewBuffer(buf[:0])
err = appendJunk(writer, peer.device.aSecCfg.initPacketJunkSize)
err = peer.device.junkCreator.appendJunk(writer, peer.device.aSecCfg.initPacketJunkSize)
if err != nil {
peer.device.log.Errorf("%v - %v", peer, err)
peer.device.aSecMux.RUnlock()
@ -175,7 +182,11 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
err = peer.SendBuffers(sendBuffer)
if err != nil {
peer.device.log.Errorf("%v - Failed to send handshake initiation: %v", peer, err)
peer.device.log.Errorf(
"%v - Failed to send handshake initiation: %v",
peer,
err,
)
}
peer.timersHandshakeInitiated()
@ -191,7 +202,11 @@ func (peer *Peer) SendHandshakeResponse() error {
response, err := peer.device.CreateMessageResponse(peer)
if err != nil {
peer.device.log.Errorf("%v - Failed to create response message: %v", peer, err)
peer.device.log.Errorf(
"%v - Failed to create response message: %v",
peer,
err,
)
return err
}
var junkedHeader []byte
@ -200,7 +215,7 @@ func (peer *Peer) SendHandshakeResponse() error {
if peer.device.aSecCfg.responsePacketJunkSize != 0 {
buf := make([]byte, 0, peer.device.aSecCfg.responsePacketJunkSize)
writer := bytes.NewBuffer(buf[:0])
err = appendJunk(writer, peer.device.aSecCfg.responsePacketJunkSize)
err = peer.device.junkCreator.appendJunk(writer, peer.device.aSecCfg.responsePacketJunkSize)
if err != nil {
peer.device.aSecMux.RUnlock()
peer.device.log.Errorf("%v - %v", peer, err)
@ -231,7 +246,11 @@ func (peer *Peer) SendHandshakeResponse() error {
// TODO: allocation could be avoided
err = peer.SendBuffers([][]byte{junkedHeader})
if err != nil {
peer.device.log.Errorf("%v - Failed to send handshake response: %v", peer, err)
peer.device.log.Errorf(
"%v - Failed to send handshake response: %v",
peer,
err,
)
}
return err
}
@ -239,7 +258,10 @@ func (peer *Peer) SendHandshakeResponse() error {
func (device *Device) SendHandshakeCookie(
initiatingElem *QueueHandshakeElement,
) error {
device.log.Verbosef("Sending cookie response for denied handshake message for %v", initiatingElem.endpoint.DstToString())
device.log.Verbosef(
"Sending cookie response for denied handshake message for %v",
initiatingElem.endpoint.DstToString(),
)
sender := binary.LittleEndian.Uint32(initiatingElem.packet[4:8])
reply, err := device.cookieChecker.CreateReply(
@ -266,7 +288,8 @@ func (peer *Peer) keepKeyFreshSending() {
return
}
nonce := keypair.sendNonce.Load()
if nonce > RekeyAfterMessages || (keypair.isInitiator && time.Since(keypair.created) > RekeyAfterTime) {
if nonce > RekeyAfterMessages ||
(keypair.isInitiator && time.Since(keypair.created) > RekeyAfterTime) {
peer.SendHandshakeInitiation(false)
}
}
@ -369,12 +392,18 @@ func (device *Device) RoutineReadFromTUN() {
// TODO: record stat for this
// This will happen if MSS is surprisingly small (< 576)
// coincident with reasonably high throughput.
device.log.Verbosef("Dropped some packets from multi-segment read: %v", readErr)
device.log.Verbosef(
"Dropped some packets from multi-segment read: %v",
readErr,
)
continue
}
if !device.isClosed() {
if !errors.Is(readErr, os.ErrClosed) {
device.log.Errorf("Failed to read packet from TUN device: %v", readErr)
device.log.Errorf(
"Failed to read packet from TUN device: %v",
readErr,
)
}
go device.Close()
}
@ -409,7 +438,8 @@ top:
}
keypair := peer.keypairs.Current()
if keypair == nil || keypair.sendNonce.Load() >= RejectAfterMessages || time.Since(keypair.created) >= RejectAfterTime {
if keypair == nil || keypair.sendNonce.Load() >= RejectAfterMessages ||
time.Since(keypair.created) >= RejectAfterTime {
peer.SendHandshakeInitiation(false)
return
}
@ -427,7 +457,10 @@ top:
if elemsContainerOOO == nil {
elemsContainerOOO = peer.device.GetOutboundElementsContainer()
}
elemsContainerOOO.elems = append(elemsContainerOOO.elems, elem)
elemsContainerOOO.elems = append(
elemsContainerOOO.elems,
elem,
)
continue
} else {
elemsContainer.elems[i] = elem
@ -440,7 +473,9 @@ top:
elemsContainer.elems = elemsContainer.elems[:i]
if elemsContainerOOO != nil {
peer.StagePackets(elemsContainerOOO) // XXX: Out of order, but we can't front-load go chans
peer.StagePackets(
elemsContainerOOO,
) // XXX: Out of order, but we can't front-load go chans
}
if len(elemsContainer.elems) == 0 {
@ -469,31 +504,6 @@ top:
}
}
func (peer *Peer) createJunkPackets() ([][]byte, error) {
if peer.device.aSecCfg.junkPacketCount == 0 {
return nil, nil
}
junks := make([][]byte, 0, peer.device.aSecCfg.junkPacketCount)
for i := 0; i < peer.device.aSecCfg.junkPacketCount; i++ {
packetSize := rand.Intn(
peer.device.aSecCfg.junkPacketMaxSize-peer.device.aSecCfg.junkPacketMinSize,
) + peer.device.aSecCfg.junkPacketMinSize
junk, err := randomJunkWithSize(packetSize)
if err != nil {
peer.device.log.Errorf(
"%v - Failed to create junk packet: %v",
peer,
err,
)
return nil, err
}
junks = append(junks, junk)
}
return junks, nil
}
func (peer *Peer) FlushStagedPackets() {
for {
select {
@ -546,11 +556,17 @@ func (device *Device) RoutineEncryption(id int) {
fieldNonce := header[8:16]
binary.LittleEndian.PutUint32(fieldType, MessageTransportType)
binary.LittleEndian.PutUint32(fieldReceiver, elem.keypair.remoteIndex)
binary.LittleEndian.PutUint32(
fieldReceiver,
elem.keypair.remoteIndex,
)
binary.LittleEndian.PutUint64(fieldNonce, elem.nonce)
// pad content to multiple of 16
paddingSize := calculatePaddingSize(len(elem.packet), int(device.tun.mtu.Load()))
paddingSize := calculatePaddingSize(
len(elem.packet),
int(device.tun.mtu.Load()),
)
elem.packet = append(elem.packet, paddingZeros[:paddingSize]...)
// encrypt content and release to consumer
@ -570,7 +586,10 @@ func (device *Device) RoutineEncryption(id int) {
func (peer *Peer) RoutineSequentialSender(maxBatchSize int) {
device := peer.device
defer func() {
defer device.log.Verbosef("%v - Routine: sequential sender - stopped", peer)
defer device.log.Verbosef(
"%v - Routine: sequential sender - stopped",
peer,
)
peer.stopping.Done()
}()
device.log.Verbosef("%v - Routine: sequential sender - started", peer)

View file

@ -1,25 +0,0 @@
package device
import (
"bytes"
crand "crypto/rand"
"fmt"
)
func appendJunk(writer *bytes.Buffer, size int) error {
headerJunk, err := randomJunkWithSize(size)
if err != nil {
return fmt.Errorf("failed to create header junk: %v", err)
}
_, err = writer.Write(headerJunk)
if err != nil {
return fmt.Errorf("failed to write header junk: %v", err)
}
return nil
}
func randomJunkWithSize(size int) ([]byte, error) {
junk := make([]byte, size)
_, err := crand.Read(junk)
return junk, err
}

View file

@ -1,27 +0,0 @@
package device
import (
"bytes"
"fmt"
"testing"
)
func Test_randomJunktWithSize(t *testing.T) {
junk, err := randomJunkWithSize(30)
fmt.Println(string(junk), len(junk), err)
}
func Test_appendJunk(t *testing.T) {
t.Run("", func(t *testing.T) {
s := "apple"
buffer := bytes.NewBuffer([]byte(s))
err := appendJunk(buffer, 30)
if err != nil &&
buffer.Len() != len(s)+30 {
t.Errorf("appendWithJunk() size don't match")
}
read := make([]byte, 50)
buffer.Read(read)
fmt.Println(string(read))
})
}

2
go.mod
View file

@ -1,6 +1,6 @@
module github.com/amnezia-vpn/amneziawg-go
go 1.22.3
go 1.23
require (
github.com/tevino/abool/v2 v2.1.0

View file

@ -59,7 +59,12 @@ func warning() {
func main() {
if len(os.Args) == 2 && os.Args[1] == "--version" {
fmt.Printf("amneziawg-go %s\n\nUserspace AmneziaWG daemon for %s-%s.\nInformation available at https://amnezia.org\n", Version, runtime.GOOS, runtime.GOARCH)
fmt.Printf(
"amneziawg-go %s\n\nUserspace AmneziaWG daemon for %s-%s.\nInformation available at https://amnezia.org\n",
Version,
runtime.GOOS,
runtime.GOARCH,
)
return
}