refactor cfg & integrate

Signed-off-by: Mark Puha <marko10@inf.elte.hu>
This commit is contained in:
Mark Puha 2023-09-05 07:03:36 +02:00
parent 999cc3212d
commit 6c2c04ebd4
10 changed files with 573 additions and 197 deletions

View file

@ -1,75 +0,0 @@
package cfg
import "log"
func init() {
if IsAdvancedSecurityOn() {
if JunkPacketCount < 0 {
log.Fatalf("JunkPacketCount should be non negative")
}
if JunkPacketMaxSize <= JunkPacketMinSize {
log.Fatalf(
"MaxSize: %d; should be greater than MinSize: %d",
JunkPacketMaxSize,
JunkPacketMinSize,
)
}
const MaxSegmentSize = 2048 - 32
if JunkPacketMaxSize >= MaxSegmentSize {
log.Fatalf(
"JunkPacketMaxSize: %d; should be smaller than maxSegmentSize: %d",
JunkPacketMaxSize,
MaxSegmentSize,
)
}
if 148+InitPacketJunkSize >= MaxSegmentSize {
log.Fatalf(
"init header size(148) + junkSize:%d; should be smaller than maxSegmentSize: %d",
InitPacketJunkSize,
MaxSegmentSize,
)
}
if 92+ResponsePacketJunkSize >= MaxSegmentSize {
log.Fatalf(
"response header size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d",
ResponsePacketJunkSize,
MaxSegmentSize,
)
}
if 64+UnderLoadPacketJunkSize >= MaxSegmentSize {
log.Fatalf(
"underload packet size(64) + junkSize:%d; should be smaller than maxSegmentSize: %d",
UnderLoadPacketJunkSize,
MaxSegmentSize,
)
}
if 32+TransportPacketJunkSize >= MaxSegmentSize {
log.Fatalf(
"transport packet size(32) + junkSize:%d should be smaller than maxSegmentSize: %d",
TransportPacketJunkSize,
MaxSegmentSize,
)
}
if UnderLoadPacketJunkSize != 0 || TransportPacketJunkSize != 0 {
log.Fatal(
`UnderLoadPacketJunkSize and TransportPacketJunkSize;
are currently unimplemented and should be left 0`,
)
}
} else {
if InitPacketJunkSize != 0 ||
ResponsePacketJunkSize != 0 ||
UnderLoadPacketJunkSize != 0 ||
TransportPacketJunkSize != 0 {
log.Fatal("JunkSizes should be zero when advanced security on")
}
}
}
func IsAdvancedSecurityOn() bool {
return InitPacketMagicHeader != 1 ||
ResponsePacketMagicHeader != 2 ||
UnderloadPacketMagicHeader != 3 ||
TransportPacketMagicHeader != 4
}

View file

@ -1,11 +0,0 @@
junk_packet_count: 5
junk_packet_min_size: 10
junk_packet_max_size: 30
init_packet_junk_size: 0
response_packet_junk_size: 0
underload_packet_junk_size: 0
transport_packet_junk_size: 0
init_packet_magic_header: 1
response_packet_magic_header : 2
underload_packet_magic_header : 3
transport_packet_magic_header : 4

View file

@ -1,11 +0,0 @@
junk_packet_count: 5
junk_packet_min_size: 10
junk_packet_max_size: 30
init_packet_junk_size: 30
response_packet_junk_size: 50
underload_packet_junk_size: 0
transport_packet_junk_size: 0
init_packet_magic_header: 1234567
response_packet_magic_header : 7654321
underload_packet_magic_header : 12345687
transport_packet_magic_header : 146810

View file

@ -6,6 +6,7 @@
package device package device
import ( import (
"log"
"runtime" "runtime"
"sync" "sync"
"sync/atomic" "sync/atomic"
@ -89,6 +90,18 @@ type Device struct {
ipcMutex sync.RWMutex ipcMutex sync.RWMutex
closed chan struct{} closed chan struct{}
log *Logger log *Logger
aSecCfg struct {
isOn bool
junkPacketCount int
junkPacketMinSize int
junkPacketMaxSize int
initPacketJunkSize int
responsePacketJunkSize int
initPacketMagicHeader uint32
responsePacketMagicHeader uint32
underloadPacketMagicHeader uint32
transportPacketMagicHeader uint32
}
} }
// deviceState represents the state of a Device. // deviceState represents the state of a Device.
@ -142,7 +155,10 @@ func (device *Device) changeState(want deviceState) (err error) {
old := device.deviceState() old := device.deviceState()
if old == deviceStateClosed { if old == deviceStateClosed {
// once closed, always closed // once closed, always closed
device.log.Verbosef("Interface closed, ignored requested state %s", want) device.log.Verbosef(
"Interface closed, ignored requested state %s",
want,
)
return nil return nil
} }
switch want { switch want {
@ -162,7 +178,12 @@ func (device *Device) changeState(want deviceState) (err error) {
err = errDown err = errDown
} }
} }
device.log.Verbosef("Interface state was %s, requested %s, now %s", old, want, device.deviceState()) device.log.Verbosef(
"Interface state was %s, requested %s, now %s",
old,
want,
device.deviceState(),
)
return return
} }
@ -267,7 +288,9 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
expiredPeers := make([]*Peer, 0, len(device.peers.keyMap)) expiredPeers := make([]*Peer, 0, len(device.peers.keyMap))
for _, peer := range device.peers.keyMap { for _, peer := range device.peers.keyMap {
handshake := &peer.handshake handshake := &peer.handshake
handshake.precomputedStaticStatic, _ = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic) handshake.precomputedStaticStatic, _ = device.staticIdentity.privateKey.sharedSecret(
handshake.remoteStatic,
)
expiredPeers = append(expiredPeers, peer) expiredPeers = append(expiredPeers, peer)
} }
@ -411,7 +434,9 @@ func (device *Device) SendKeepalivesToPeersWithCurrentKeypair() {
device.peers.RLock() device.peers.RLock()
for _, peer := range device.peers.keyMap { for _, peer := range device.peers.keyMap {
peer.keypairs.RLock() peer.keypairs.RLock()
sendKeepalive := peer.keypairs.current != nil && !peer.keypairs.current.created.Add(RejectAfterTime).Before(time.Now()) sendKeepalive := peer.keypairs.current != nil &&
!peer.keypairs.current.created.Add(RejectAfterTime).
Before(time.Now())
peer.keypairs.RUnlock() peer.keypairs.RUnlock()
if sendKeepalive { if sendKeepalive {
peer.SendKeepalive() peer.SendKeepalive()
@ -525,8 +550,12 @@ func (device *Device) BindUpdate() error {
// start receiving routines // start receiving routines
device.net.stopping.Add(len(recvFns)) device.net.stopping.Add(len(recvFns))
device.queue.decryption.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption device.queue.decryption.wg.Add(
device.queue.handshake.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.handshake len(recvFns),
) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption
device.queue.handshake.wg.Add(
len(recvFns),
) // each RoutineReceiveIncoming goroutine writes to device.queue.handshake
batchSize := netc.bind.BatchSize() batchSize := netc.bind.BatchSize()
for _, fn := range recvFns { for _, fn := range recvFns {
go device.RoutineReceiveIncoming(batchSize, fn) go device.RoutineReceiveIncoming(batchSize, fn)
@ -542,3 +571,57 @@ func (device *Device) BindClose() error {
device.net.Unlock() device.net.Unlock()
return err return err
} }
func (device *Device) isAdvancedSecurityOn() bool {
return device.aSecCfg.isOn
}
func (device *Device) handlePostConfig() {
if device.isAdvancedSecurityOn() {
if device.aSecCfg.junkPacketMaxSize >= 0 {
if device.aSecCfg.junkPacketMaxSize == device.aSecCfg.junkPacketMinSize {
device.aSecCfg.junkPacketMaxSize++ // to make rand gen work
} else if device.aSecCfg.junkPacketMaxSize < device.aSecCfg.junkPacketMinSize {
log.Fatalf(
"MaxSize: %d; should be greater than MinSize: %d",
device.aSecCfg.junkPacketMaxSize,
device.aSecCfg.junkPacketMinSize,
)
}
}
if device.aSecCfg.initPacketMagicHeader != 0 &&
device.aSecCfg.initPacketMagicHeader != 1 {
MessageInitiationType = device.aSecCfg.initPacketMagicHeader
}
if device.aSecCfg.responsePacketMagicHeader != 0 &&
device.aSecCfg.responsePacketMagicHeader != 1 {
MessageResponseType = device.aSecCfg.responsePacketMagicHeader
}
if device.aSecCfg.underloadPacketMagicHeader != 0 &&
device.aSecCfg.underloadPacketMagicHeader != 1 {
MessageCookieReplyType = device.aSecCfg.underloadPacketMagicHeader
}
if device.aSecCfg.transportPacketMagicHeader != 0 &&
device.aSecCfg.transportPacketMagicHeader != 1 {
MessageTransportType = device.aSecCfg.transportPacketMagicHeader
}
packetSizeToMsgType = map[int]uint32{
MessageInitiationSize + device.aSecCfg.initPacketJunkSize: MessageInitiationType,
MessageResponseSize + device.aSecCfg.responsePacketJunkSize: MessageResponseType,
MessageCookieReplySize: MessageCookieReplyType,
MessageTransportSize: MessageTransportType,
}
msgTypeToJunkSize = map[uint32]int{
MessageInitiationType: device.aSecCfg.initPacketJunkSize,
MessageResponseType: device.aSecCfg.responsePacketJunkSize,
MessageCookieReplyType: 0,
MessageTransportType: 0,
}
}
}

View file

@ -91,6 +91,65 @@ func genConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) {
return return
} }
func genASecurityConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) {
var key1, key2 NoisePrivateKey
_, err := rand.Read(key1[:])
if err != nil {
tb.Errorf("unable to generate private key random bytes: %v", err)
}
_, err = rand.Read(key2[:])
if err != nil {
tb.Errorf("unable to generate private key random bytes: %v", err)
}
pub1, pub2 := key1.publicKey(), key2.publicKey()
cfgs[0] = uapiCfg(
"private_key", hex.EncodeToString(key1[:]),
"listen_port", "0",
"replace_peers", "true",
"jc", "5",
"jmin", "500",
"jmax", "501",
"s1", "30",
"s2", "40",
"h1", "123456",
"h2", "67543",
"h4", "32345",
"h3", "123123",
"public_key", hex.EncodeToString(pub2[:]),
"protocol_version", "1",
"replace_allowed_ips", "true",
"allowed_ip", "1.0.0.2/32",
)
endpointCfgs[0] = uapiCfg(
"public_key", hex.EncodeToString(pub2[:]),
"endpoint", "127.0.0.1:%d",
)
cfgs[1] = uapiCfg(
"private_key", hex.EncodeToString(key2[:]),
"listen_port", "0",
"replace_peers", "true",
"jc", "5",
"jmin", "500",
"jmax", "501",
"s1", "30",
"s2", "40",
"h1", "123456",
"h2", "67543",
"h4", "32345",
"h3", "123123",
"public_key", hex.EncodeToString(pub1[:]),
"protocol_version", "1",
"replace_allowed_ips", "true",
"allowed_ip", "1.0.0.1/32",
)
endpointCfgs[1] = uapiCfg(
"public_key", hex.EncodeToString(pub1[:]),
"endpoint", "127.0.0.1:%d",
)
return
}
// A testPair is a pair of testPeers. // A testPair is a pair of testPeers.
type testPair [2]testPeer type testPair [2]testPeer
@ -115,7 +174,11 @@ func (d SendDirection) String() string {
return "pong" return "pong"
} }
func (pair *testPair) Send(tb testing.TB, ping SendDirection, done chan struct{}) { func (pair *testPair) Send(
tb testing.TB,
ping SendDirection,
done chan struct{},
) {
tb.Helper() tb.Helper()
p0, p1 := pair[0], pair[1] p0, p1 := pair[0], pair[1]
if !ping { if !ping {
@ -149,8 +212,16 @@ func (pair *testPair) Send(tb testing.TB, ping SendDirection, done chan struct{}
} }
// genTestPair creates a testPair. // genTestPair creates a testPair.
func genTestPair(tb testing.TB, realSocket bool) (pair testPair) { func genTestPair(
cfg, endpointCfg := genConfigs(tb) tb testing.TB,
realSocket, withASecurity bool,
) (pair testPair) {
var cfg, endpointCfg [2]string
if withASecurity {
cfg, endpointCfg = genASecurityConfigs(tb)
} else {
cfg, endpointCfg = genConfigs(tb)
}
var binds [2]conn.Bind var binds [2]conn.Bind
if realSocket { if realSocket {
binds[0], binds[1] = conn.NewDefaultBind(), conn.NewDefaultBind() binds[0], binds[1] = conn.NewDefaultBind(), conn.NewDefaultBind()
@ -166,7 +237,11 @@ func genTestPair(tb testing.TB, realSocket bool) (pair testPair) {
if _, ok := tb.(*testing.B); ok && !testing.Verbose() { if _, ok := tb.(*testing.B); ok && !testing.Verbose() {
level = LogLevelError level = LogLevelError
} }
p.dev = NewDevice(p.tun.TUN(), binds[i], NewLogger(level, fmt.Sprintf("dev%d: ", i))) p.dev = NewDevice(
p.tun.TUN(),
binds[i],
NewLogger(level, fmt.Sprintf("dev%d: ", i)),
)
if err := p.dev.IpcSet(cfg[i]); err != nil { if err := p.dev.IpcSet(cfg[i]); err != nil {
tb.Errorf("failed to configure device %d: %v", i, err) tb.Errorf("failed to configure device %d: %v", i, err)
p.dev.Close() p.dev.Close()
@ -194,7 +269,18 @@ func genTestPair(tb testing.TB, realSocket bool) (pair testPair) {
func TestTwoDevicePing(t *testing.T) { func TestTwoDevicePing(t *testing.T) {
goroutineLeakCheck(t) goroutineLeakCheck(t)
pair := genTestPair(t, true) pair := genTestPair(t, true, false)
t.Run("ping 1.0.0.1", func(t *testing.T) {
pair.Send(t, Ping, nil)
})
t.Run("ping 1.0.0.2", func(t *testing.T) {
pair.Send(t, Pong, nil)
})
}
func TestTwoDevicePingASecurity(t *testing.T) {
goroutineLeakCheck(t)
pair := genTestPair(t, true, true)
t.Run("ping 1.0.0.1", func(t *testing.T) { t.Run("ping 1.0.0.1", func(t *testing.T) {
pair.Send(t, Ping, nil) pair.Send(t, Ping, nil)
}) })
@ -209,10 +295,15 @@ func TestUpDown(t *testing.T) {
const otrials = 10 const otrials = 10
for n := 0; n < otrials; n++ { for n := 0; n < otrials; n++ {
pair := genTestPair(t, false) pair := genTestPair(t, false, false)
for i := range pair { for i := range pair {
for k := range pair[i].dev.peers.keyMap { for k := range pair[i].dev.peers.keyMap {
pair[i].dev.IpcSet(fmt.Sprintf("public_key=%s\npersistent_keepalive_interval=1\n", hex.EncodeToString(k[:]))) pair[i].dev.IpcSet(
fmt.Sprintf(
"public_key=%s\npersistent_keepalive_interval=1\n",
hex.EncodeToString(k[:]),
),
)
} }
} }
var wg sync.WaitGroup var wg sync.WaitGroup
@ -224,11 +315,19 @@ func TestUpDown(t *testing.T) {
if err := d.Up(); err != nil { if err := d.Up(); err != nil {
t.Errorf("failed up bring up device: %v", err) t.Errorf("failed up bring up device: %v", err)
} }
time.Sleep(time.Duration(rand.Intn(int(time.Nanosecond * (0x10000 - 1))))) time.Sleep(
time.Duration(
rand.Intn(int(time.Nanosecond * (0x10000 - 1))),
),
)
if err := d.Down(); err != nil { if err := d.Down(); err != nil {
t.Errorf("failed to bring down device: %v", err) t.Errorf("failed to bring down device: %v", err)
} }
time.Sleep(time.Duration(rand.Intn(int(time.Nanosecond * (0x10000 - 1))))) time.Sleep(
time.Duration(
rand.Intn(int(time.Nanosecond * (0x10000 - 1))),
),
)
} }
}(pair[i].dev) }(pair[i].dev)
} }
@ -243,7 +342,7 @@ func TestUpDown(t *testing.T) {
// TestConcurrencySafety does other things concurrently with tunnel use. // TestConcurrencySafety does other things concurrently with tunnel use.
// It is intended to be used with the race detector to catch data races. // It is intended to be used with the race detector to catch data races.
func TestConcurrencySafety(t *testing.T) { func TestConcurrencySafety(t *testing.T) {
pair := genTestPair(t, true) pair := genTestPair(t, true, false)
done := make(chan struct{}) done := make(chan struct{})
const warmupIters = 10 const warmupIters = 10
@ -294,8 +393,14 @@ func TestConcurrencySafety(t *testing.T) {
// Change private keys concurrently with tunnel use. // Change private keys concurrently with tunnel use.
t.Run("privateKey", func(t *testing.T) { t.Run("privateKey", func(t *testing.T) {
bad := uapiCfg("private_key", "7777777777777777777777777777777777777777777777777777777777777777") bad := uapiCfg(
good := uapiCfg("private_key", hex.EncodeToString(pair[0].dev.staticIdentity.privateKey[:])) "private_key",
"7777777777777777777777777777777777777777777777777777777777777777",
)
good := uapiCfg(
"private_key",
hex.EncodeToString(pair[0].dev.staticIdentity.privateKey[:]),
)
// Set iters to a large number like 1000 to flush out data races quickly. // Set iters to a large number like 1000 to flush out data races quickly.
// Don't leave it large. That can cause logical races // Don't leave it large. That can cause logical races
// in which the handshake is interleaved with key changes // in which the handshake is interleaved with key changes
@ -324,7 +429,7 @@ func TestConcurrencySafety(t *testing.T) {
} }
func BenchmarkLatency(b *testing.B) { func BenchmarkLatency(b *testing.B) {
pair := genTestPair(b, true) pair := genTestPair(b, true, false)
// Establish a connection. // Establish a connection.
pair.Send(b, Ping, nil) pair.Send(b, Ping, nil)
@ -338,7 +443,7 @@ func BenchmarkLatency(b *testing.B) {
} }
func BenchmarkThroughput(b *testing.B) { func BenchmarkThroughput(b *testing.B) {
pair := genTestPair(b, true) pair := genTestPair(b, true, false)
// Establish a connection. // Establish a connection.
pair.Send(b, Ping, nil) pair.Send(b, Ping, nil)
@ -382,7 +487,7 @@ func BenchmarkThroughput(b *testing.B) {
} }
func BenchmarkUAPIGet(b *testing.B) { func BenchmarkUAPIGet(b *testing.B) {
pair := genTestPair(b, true) pair := genTestPair(b, true, false)
pair.Send(b, Ping, nil) pair.Send(b, Ping, nil)
pair.Send(b, Pong, nil) pair.Send(b, Pong, nil)
b.ReportAllocs() b.ReportAllocs()
@ -415,7 +520,11 @@ func goroutineLeakCheck(t *testing.T) {
endGoroutines, endStacks := goroutines() endGoroutines, endStacks := goroutines()
t.Logf("starting stacks:\n%s\n", startStacks) t.Logf("starting stacks:\n%s\n", startStacks)
t.Logf("ending stacks:\n%s\n", endStacks) t.Logf("ending stacks:\n%s\n", endStacks)
t.Fatalf("expected %d goroutines, got %d, leak?", startGoroutines, endGoroutines) t.Fatalf(
"expected %d goroutines, got %d, leak?",
startGoroutines,
endGoroutines,
)
}) })
} }
@ -423,29 +532,65 @@ type fakeBindSized struct {
size int size int
} }
func (b *fakeBindSized) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) { func (b *fakeBindSized) Open(
port uint16,
) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
return nil, 0, nil return nil, 0, nil
} }
func (b *fakeBindSized) Close() error { return nil }
func (b *fakeBindSized) SetMark(mark uint32) error { return nil } func (b *fakeBindSized) Close() error { return nil }
func (b *fakeBindSized) Send(bufs [][]byte, ep conn.Endpoint) error { return nil }
func (b *fakeBindSized) ParseEndpoint(s string) (conn.Endpoint, error) { return nil, nil } func (b *fakeBindSized) SetMark(
func (b *fakeBindSized) BatchSize() int { return b.size } mark uint32,
) error {
return nil
}
func (b *fakeBindSized) Send(
bufs [][]byte,
ep conn.Endpoint,
) error {
return nil
}
func (b *fakeBindSized) ParseEndpoint(
s string,
) (conn.Endpoint, error) {
return nil, nil
}
func (b *fakeBindSized) BatchSize() int { return b.size }
type fakeTUNDeviceSized struct { type fakeTUNDeviceSized struct {
size int size int
} }
func (t *fakeTUNDeviceSized) File() *os.File { return nil } func (t *fakeTUNDeviceSized) File() *os.File { return nil }
func (t *fakeTUNDeviceSized) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
func (t *fakeTUNDeviceSized) Read(
bufs [][]byte,
sizes []int,
offset int,
) (n int, err error) {
return 0, nil return 0, nil
} }
func (t *fakeTUNDeviceSized) Write(bufs [][]byte, offset int) (int, error) { return 0, nil }
func (t *fakeTUNDeviceSized) MTU() (int, error) { return 0, nil } func (t *fakeTUNDeviceSized) Write(
func (t *fakeTUNDeviceSized) Name() (string, error) { return "", nil } bufs [][]byte,
func (t *fakeTUNDeviceSized) Events() <-chan tun.Event { return nil } offset int,
func (t *fakeTUNDeviceSized) Close() error { return nil } ) (int, error) {
func (t *fakeTUNDeviceSized) BatchSize() int { return t.size } return 0, nil
}
func (t *fakeTUNDeviceSized) MTU() (int, error) { return 0, nil }
func (t *fakeTUNDeviceSized) Name() (string, error) { return "", nil }
func (t *fakeTUNDeviceSized) Events() <-chan tun.Event { return nil }
func (t *fakeTUNDeviceSized) Close() error { return nil }
func (t *fakeTUNDeviceSized) BatchSize() int { return t.size }
func TestBatchSize(t *testing.T) { func TestBatchSize(t *testing.T) {
d := Device{} d := Device{}

View file

@ -15,7 +15,6 @@ import (
"golang.org/x/crypto/chacha20poly1305" "golang.org/x/crypto/chacha20poly1305"
"golang.org/x/crypto/poly1305" "golang.org/x/crypto/poly1305"
"golang.zx2c4.com/wireguard/cfg"
"golang.zx2c4.com/wireguard/tai64n" "golang.zx2c4.com/wireguard/tai64n"
) )
@ -53,11 +52,11 @@ const (
WGLabelCookie = "cookie--" WGLabelCookie = "cookie--"
) )
const ( var (
MessageInitiationType = cfg.InitPacketMagicHeader MessageInitiationType uint32 = 1
MessageResponseType = cfg.ResponsePacketMagicHeader MessageResponseType uint32 = 2
MessageCookieReplyType = cfg.UnderloadPacketMagicHeader MessageCookieReplyType uint32 = 3
MessageTransportType = cfg.TransportPacketMagicHeader MessageTransportType uint32 = 4
) )
const ( const (
@ -76,19 +75,9 @@ const (
MessageTransportOffsetContent = 16 MessageTransportOffsetContent = 16
) )
var packetSizeToMsgType = map[int]uint32{ var packetSizeToMsgType map[int]uint32
MessageInitiationSize + cfg.InitPacketJunkSize: MessageInitiationType,
MessageResponseSize + cfg.ResponsePacketJunkSize: MessageResponseType,
MessageCookieReplySize + cfg.UnderLoadPacketJunkSize: MessageCookieReplyType,
MessageTransportSize + cfg.TransportPacketJunkSize: MessageTransportType,
}
var msgTypeToJunkSize = map[uint32]int{ var msgTypeToJunkSize map[uint32]int
MessageInitiationType: cfg.InitPacketJunkSize,
MessageResponseType: cfg.ResponsePacketJunkSize,
MessageCookieReplyType: cfg.UnderLoadPacketJunkSize,
MessageTransportType: cfg.TransportPacketJunkSize,
}
/* Type is an 8-bit field, followed by 3 nul bytes, /* Type is an 8-bit field, followed by 3 nul bytes,
* by marshalling the messages in little-endian byteorder * by marshalling the messages in little-endian byteorder

View file

@ -16,7 +16,6 @@ import (
"golang.org/x/crypto/chacha20poly1305" "golang.org/x/crypto/chacha20poly1305"
"golang.org/x/net/ipv4" "golang.org/x/net/ipv4"
"golang.org/x/net/ipv6" "golang.org/x/net/ipv6"
"golang.zx2c4.com/wireguard/cfg"
"golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/conn"
) )
@ -142,12 +141,10 @@ func (device *Device) RoutineReceiveIncoming(
// check size of packet // check size of packet
packet := bufsArrs[i][:size] packet := bufsArrs[i][:size]
if cfg.IsAdvancedSecurityOn() { if device.isAdvancedSecurityOn() {
var junkSize int var junkSize int
if mapMsgType, ok := packetSizeToMsgType[size]; ok { if msgType, ok := packetSizeToMsgType[size]; ok {
junkSize = msgTypeToJunkSize[mapMsgType] junkSize = msgTypeToJunkSize[msgType]
} else {
junkSize = cfg.TransportPacketJunkSize
} }
// shift junk // shift junk
packet = packet[junkSize:] packet = packet[junkSize:]

View file

@ -19,7 +19,6 @@ import (
"golang.org/x/crypto/chacha20poly1305" "golang.org/x/crypto/chacha20poly1305"
"golang.org/x/net/ipv4" "golang.org/x/net/ipv4"
"golang.org/x/net/ipv6" "golang.org/x/net/ipv6"
"golang.zx2c4.com/wireguard/cfg"
"golang.zx2c4.com/wireguard/tun" "golang.zx2c4.com/wireguard/tun"
) )
@ -128,15 +127,15 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
} }
// so only packet processed for cookie generation // so only packet processed for cookie generation
var junkedHeader []byte var junkedHeader []byte
if cfg.IsAdvancedSecurityOn() { if peer.device.isAdvancedSecurityOn() {
err = peer.sendJunkPackets() err = peer.sendJunkPackets()
if err != nil { if err != nil {
peer.device.log.Errorf("%v - %v", peer, err) peer.device.log.Errorf("%v - %v", peer, err)
return err return err
} }
var buf [cfg.InitPacketJunkSize]byte buf := make([]byte, 0, peer.device.aSecCfg.initPacketJunkSize)
writer := bytes.NewBuffer(buf[:0]) writer := bytes.NewBuffer(buf[:0])
err = appendJunk(writer, cfg.InitPacketJunkSize) err = appendJunk(writer, peer.device.aSecCfg.initPacketJunkSize)
if err != nil { if err != nil {
peer.device.log.Errorf("%v - %v", peer, err) peer.device.log.Errorf("%v - %v", peer, err)
return err return err
@ -183,10 +182,10 @@ func (peer *Peer) SendHandshakeResponse() error {
return err return err
} }
var junkedHeader []byte var junkedHeader []byte
if cfg.IsAdvancedSecurityOn() { if peer.device.isAdvancedSecurityOn() {
var buf [cfg.ResponsePacketJunkSize]byte buf := make([]byte, 0, peer.device.aSecCfg.responsePacketJunkSize)
writer := bytes.NewBuffer(buf[:0]) writer := bytes.NewBuffer(buf[:0])
err = appendJunk(writer, cfg.ResponsePacketJunkSize) err = appendJunk(writer, peer.device.aSecCfg.responsePacketJunkSize)
if err != nil { if err != nil {
peer.device.log.Errorf("%v - %v", peer, err) peer.device.log.Errorf("%v - %v", peer, err)
return err return err
@ -472,11 +471,11 @@ top:
} }
func (peer *Peer) sendJunkPackets() error { func (peer *Peer) sendJunkPackets() error {
junks := make([][]byte, 0, cfg.JunkPacketCount) junks := make([][]byte, 0, peer.device.aSecCfg.junkPacketCount)
for i := 0; i < cfg.JunkPacketCount; i++ { for i := 0; i < peer.device.aSecCfg.junkPacketCount; i++ {
packetSize := rand.Intn( packetSize := rand.Intn(
cfg.JunkPacketMaxSize-cfg.JunkPacketMinSize, peer.device.aSecCfg.junkPacketMaxSize-peer.device.aSecCfg.junkPacketMinSize,
) + cfg.JunkPacketMinSize ) + peer.device.aSecCfg.junkPacketMinSize
junk, err := randomJunkWithSize(packetSize) junk, err := randomJunkWithSize(packetSize)
if err != nil { if err != nil {

View file

@ -11,6 +11,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"log"
"net" "net"
"net/netip" "net/netip"
"strconv" "strconv"
@ -97,6 +98,36 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
sendf("fwmark=%d", device.net.fwmark) sendf("fwmark=%d", device.net.fwmark)
} }
if device.isAdvancedSecurityOn() {
if device.aSecCfg.junkPacketCount != 0 {
sendf("jc=%d", device.aSecCfg.junkPacketCount)
}
if device.aSecCfg.junkPacketMinSize != 0 {
sendf("jmin=%d", device.aSecCfg.junkPacketMinSize)
}
if device.aSecCfg.junkPacketMaxSize != 0 {
sendf("jmax=%d", device.aSecCfg.junkPacketMaxSize)
}
if device.aSecCfg.initPacketJunkSize != 0 {
sendf("s1=%d", device.aSecCfg.initPacketJunkSize)
}
if device.aSecCfg.responsePacketJunkSize != 0 {
sendf("s2=%d", device.aSecCfg.responsePacketJunkSize)
}
if device.aSecCfg.initPacketMagicHeader != 0 {
sendf("h1=%d", device.aSecCfg.initPacketMagicHeader)
}
if device.aSecCfg.responsePacketMagicHeader != 0 {
sendf("h2=%d", device.aSecCfg.responsePacketMagicHeader)
}
if device.aSecCfg.underloadPacketMagicHeader != 0 {
sendf("h3=%d", device.aSecCfg.underloadPacketMagicHeader)
}
if device.aSecCfg.transportPacketMagicHeader != 0 {
sendf("h4=%d", device.aSecCfg.transportPacketMagicHeader)
}
}
for _, peer := range device.peers.keyMap { for _, peer := range device.peers.keyMap {
// Serialize peer state. // Serialize peer state.
// Do the work in an anonymous function so that we can use defer. // Do the work in an anonymous function so that we can use defer.
@ -119,12 +150,18 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
sendf("last_handshake_time_nsec=%d", nano) sendf("last_handshake_time_nsec=%d", nano)
sendf("tx_bytes=%d", peer.txBytes.Load()) sendf("tx_bytes=%d", peer.txBytes.Load())
sendf("rx_bytes=%d", peer.rxBytes.Load()) sendf("rx_bytes=%d", peer.rxBytes.Load())
sendf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval.Load()) sendf(
"persistent_keepalive_interval=%d",
peer.persistentKeepaliveInterval.Load(),
)
device.allowedips.EntriesForPeer(peer, func(prefix netip.Prefix) bool { device.allowedips.EntriesForPeer(
sendf("allowed_ip=%s", prefix.String()) peer,
return true func(prefix netip.Prefix) bool {
}) sendf("allowed_ip=%s", prefix.String())
return true
},
)
}() }()
} }
}() }()
@ -162,7 +199,11 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
} }
key, value, ok := strings.Cut(line, "=") key, value, ok := strings.Cut(line, "=")
if !ok { if !ok {
return ipcErrorf(ipc.IpcErrorProtocol, "failed to parse line %q", line) return ipcErrorf(
ipc.IpcErrorProtocol,
"failed to parse line %q",
line,
)
} }
if key == "public_key" { if key == "public_key" {
@ -188,6 +229,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
return err return err
} }
} }
device.handlePostConfig()
peer.handlePostConfig() peer.handlePostConfig()
if err := scanner.Err(); err != nil { if err := scanner.Err(); err != nil {
@ -202,7 +244,11 @@ func (device *Device) handleDeviceLine(key, value string) error {
var sk NoisePrivateKey var sk NoisePrivateKey
err := sk.FromMaybeZeroHex(value) err := sk.FromMaybeZeroHex(value)
if err != nil { if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set private_key: %w", err) return ipcErrorf(
ipc.IpcErrorInvalid,
"failed to set private_key: %w",
err,
)
} }
device.log.Verbosef("UAPI: Updating private key") device.log.Verbosef("UAPI: Updating private key")
device.SetPrivateKey(sk) device.SetPrivateKey(sk)
@ -210,7 +256,11 @@ func (device *Device) handleDeviceLine(key, value string) error {
case "listen_port": case "listen_port":
port, err := strconv.ParseUint(value, 10, 16) port, err := strconv.ParseUint(value, 10, 16)
if err != nil { if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to parse listen_port: %w", err) return ipcErrorf(
ipc.IpcErrorInvalid,
"failed to parse listen_port: %w",
err,
)
} }
// update port and rebind // update port and rebind
@ -221,7 +271,11 @@ func (device *Device) handleDeviceLine(key, value string) error {
device.net.Unlock() device.net.Unlock()
if err := device.BindUpdate(); err != nil { if err := device.BindUpdate(); err != nil {
return ipcErrorf(ipc.IpcErrorPortInUse, "failed to set listen_port: %w", err) return ipcErrorf(
ipc.IpcErrorPortInUse,
"failed to set listen_port: %w",
err,
)
} }
case "fwmark": case "fwmark":
@ -232,18 +286,171 @@ func (device *Device) handleDeviceLine(key, value string) error {
device.log.Verbosef("UAPI: Updating fwmark") device.log.Verbosef("UAPI: Updating fwmark")
if err := device.BindSetMark(uint32(mark)); err != nil { if err := device.BindSetMark(uint32(mark)); err != nil {
return ipcErrorf(ipc.IpcErrorPortInUse, "failed to update fwmark: %w", err) return ipcErrorf(
ipc.IpcErrorPortInUse,
"failed to update fwmark: %w",
err,
)
} }
case "replace_peers": case "replace_peers":
if value != "true" { if value != "true" {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set replace_peers, invalid value: %v", value) return ipcErrorf(
ipc.IpcErrorInvalid,
"failed to set replace_peers, invalid value: %v",
value,
)
} }
device.log.Verbosef("UAPI: Removing all peers") device.log.Verbosef("UAPI: Removing all peers")
device.RemoveAllPeers() device.RemoveAllPeers()
case "jc":
junkPacketCount, err := strconv.Atoi(value)
if err != nil {
return ipcErrorf(
ipc.IpcErrorInvalid,
"faield to parse junk_packet_count %w",
err,
)
}
if junkPacketCount < 0 {
log.Fatalf("JunkPacketCount should be non negative")
}
device.log.Verbosef("UAPI: Updating junk_packet_count")
device.aSecCfg.isOn = true
device.aSecCfg.junkPacketCount = junkPacketCount
case "jmin":
junkPacketMinSize, err := strconv.Atoi(value)
if err != nil {
return ipcErrorf(
ipc.IpcErrorInvalid,
"faield to parse junk_packet_min_size %w",
err,
)
}
device.log.Verbosef("UAPI: Updating junk_packet_min_size")
device.aSecCfg.isOn = true
device.aSecCfg.junkPacketMinSize = junkPacketMinSize
case "jmax":
junkPacketMaxSize, err := strconv.Atoi(value)
if err != nil {
return ipcErrorf(
ipc.IpcErrorInvalid,
"faield to parse junk_packet_max_size %w",
err,
)
}
if junkPacketMaxSize >= MaxSegmentSize {
log.Fatalf(
"JunkPacketMaxSize: %d; should be smaller than maxSegmentSize: %d",
junkPacketMaxSize,
MaxSegmentSize,
)
}
device.log.Verbosef("UAPI: Updating junk_packet_max_size")
device.aSecCfg.isOn = true
device.aSecCfg.junkPacketMaxSize = junkPacketMaxSize
case "s1":
initPacketJunkSize, err := strconv.Atoi(value)
if err != nil {
return ipcErrorf(
ipc.IpcErrorInvalid,
"faield to parse init_packet_junk_size %w",
err,
)
}
if 148+initPacketJunkSize >= MaxSegmentSize {
log.Fatalf(
`init header size(148) + junkSize:%d;
should be smaller than maxSegmentSize: %d`,
initPacketJunkSize,
MaxSegmentSize,
)
}
device.log.Verbosef("UAPI: Updating init_packet_junk_size")
device.aSecCfg.isOn = true
device.aSecCfg.initPacketJunkSize = initPacketJunkSize
case "s2":
responsePacketJunkSize, err := strconv.Atoi(value)
if err != nil {
return ipcErrorf(
ipc.IpcErrorInvalid,
"faield to parse response_packet_junk_size %w",
err,
)
}
if 92+responsePacketJunkSize >= MaxSegmentSize {
log.Fatalf(
`response header size(92) + junkSize:%d;
should be smaller than maxSegmentSize: %d`,
responsePacketJunkSize,
MaxSegmentSize,
)
}
device.log.Verbosef("UAPI: Updating response_packet_junk_size")
device.aSecCfg.isOn = true
device.aSecCfg.responsePacketJunkSize = responsePacketJunkSize
case "h1":
initPacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
if err != nil {
return ipcErrorf(
ipc.IpcErrorInvalid,
"faield to parse init_packet_magic_header %w",
err,
)
}
device.log.Verbosef("UAPI: Updating init_packet_magic_header")
device.aSecCfg.isOn = true
device.aSecCfg.initPacketMagicHeader = uint32(initPacketMagicHeader)
case "h2":
responsePacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
if err != nil {
return ipcErrorf(
ipc.IpcErrorInvalid,
"faield to parse response_packet_magic_header %w",
err,
)
}
device.log.Verbosef("UAPI: Updating response_packet_magic_header")
device.aSecCfg.isOn = true
device.aSecCfg.responsePacketMagicHeader = uint32(
responsePacketMagicHeader,
)
case "h3":
underloadPacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
if err != nil {
return ipcErrorf(
ipc.IpcErrorInvalid,
"faield to parse underload_packet_magic_header %w",
err,
)
}
device.log.Verbosef("UAPI: Updating underload_packet_magic_header")
device.aSecCfg.isOn = true
device.aSecCfg.underloadPacketMagicHeader = uint32(
underloadPacketMagicHeader,
)
case "h4":
transportPacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
if err != nil {
return ipcErrorf(
ipc.IpcErrorInvalid,
"faield to parse transport_packet_magic_header %w",
err,
)
}
device.log.Verbosef("UAPI: Updating transport_packet_magic_header")
device.aSecCfg.isOn = true
device.aSecCfg.transportPacketMagicHeader = uint32(
transportPacketMagicHeader,
)
default: default:
return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI device key: %v", key) return ipcErrorf(
ipc.IpcErrorInvalid,
"invalid UAPI device key: %v",
key,
)
} }
return nil return nil
@ -262,7 +469,8 @@ func (peer *ipcSetPeer) handlePostConfig() {
return return
} }
if peer.created { if peer.created {
peer.disableRoaming = peer.device.net.brokenRoaming && peer.endpoint != nil peer.disableRoaming = peer.device.net.brokenRoaming &&
peer.endpoint != nil
} }
if peer.device.isUp() { if peer.device.isUp() {
peer.Start() peer.Start()
@ -273,12 +481,19 @@ func (peer *ipcSetPeer) handlePostConfig() {
} }
} }
func (device *Device) handlePublicKeyLine(peer *ipcSetPeer, value string) error { func (device *Device) handlePublicKeyLine(
peer *ipcSetPeer,
value string,
) error {
// Load/create the peer we are configuring. // Load/create the peer we are configuring.
var publicKey NoisePublicKey var publicKey NoisePublicKey
err := publicKey.FromHex(value) err := publicKey.FromHex(value)
if err != nil { if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to get peer by public key: %w", err) return ipcErrorf(
ipc.IpcErrorInvalid,
"failed to get peer by public key: %w",
err,
)
} }
// Ignore peer with the same public key as this device. // Ignore peer with the same public key as this device.
@ -296,19 +511,30 @@ func (device *Device) handlePublicKeyLine(peer *ipcSetPeer, value string) error
if peer.created { if peer.created {
peer.Peer, err = device.NewPeer(publicKey) peer.Peer, err = device.NewPeer(publicKey)
if err != nil { if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to create new peer: %w", err) return ipcErrorf(
ipc.IpcErrorInvalid,
"failed to create new peer: %w",
err,
)
} }
device.log.Verbosef("%v - UAPI: Created", peer.Peer) device.log.Verbosef("%v - UAPI: Created", peer.Peer)
} }
return nil return nil
} }
func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error { func (device *Device) handlePeerLine(
peer *ipcSetPeer,
key, value string,
) error {
switch key { switch key {
case "update_only": case "update_only":
// allow disabling of creation // allow disabling of creation
if value != "true" { if value != "true" {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set update only, invalid value: %v", value) return ipcErrorf(
ipc.IpcErrorInvalid,
"failed to set update only, invalid value: %v",
value,
)
} }
if peer.created && !peer.dummy { if peer.created && !peer.dummy {
device.RemovePeer(peer.handshake.remoteStatic) device.RemovePeer(peer.handshake.remoteStatic)
@ -319,7 +545,11 @@ func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error
case "remove": case "remove":
// remove currently selected peer from device // remove currently selected peer from device
if value != "true" { if value != "true" {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set remove, invalid value: %v", value) return ipcErrorf(
ipc.IpcErrorInvalid,
"failed to set remove, invalid value: %v",
value,
)
} }
if !peer.dummy { if !peer.dummy {
device.log.Verbosef("%v - UAPI: Removing", peer.Peer) device.log.Verbosef("%v - UAPI: Removing", peer.Peer)
@ -336,25 +566,41 @@ func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error
peer.handshake.mutex.Unlock() peer.handshake.mutex.Unlock()
if err != nil { if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set preshared key: %w", err) return ipcErrorf(
ipc.IpcErrorInvalid,
"failed to set preshared key: %w",
err,
)
} }
case "endpoint": case "endpoint":
device.log.Verbosef("%v - UAPI: Updating endpoint", peer.Peer) device.log.Verbosef("%v - UAPI: Updating endpoint", peer.Peer)
endpoint, err := device.net.bind.ParseEndpoint(value) endpoint, err := device.net.bind.ParseEndpoint(value)
if err != nil { if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set endpoint %v: %w", value, err) return ipcErrorf(
ipc.IpcErrorInvalid,
"failed to set endpoint %v: %w",
value,
err,
)
} }
peer.Lock() peer.Lock()
defer peer.Unlock() defer peer.Unlock()
peer.endpoint = endpoint peer.endpoint = endpoint
case "persistent_keepalive_interval": case "persistent_keepalive_interval":
device.log.Verbosef("%v - UAPI: Updating persistent keepalive interval", peer.Peer) device.log.Verbosef(
"%v - UAPI: Updating persistent keepalive interval",
peer.Peer,
)
secs, err := strconv.ParseUint(value, 10, 16) secs, err := strconv.ParseUint(value, 10, 16)
if err != nil { if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set persistent keepalive interval: %w", err) return ipcErrorf(
ipc.IpcErrorInvalid,
"failed to set persistent keepalive interval: %w",
err,
)
} }
old := peer.persistentKeepaliveInterval.Swap(uint32(secs)) old := peer.persistentKeepaliveInterval.Swap(uint32(secs))
@ -365,7 +611,11 @@ func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error
case "replace_allowed_ips": case "replace_allowed_ips":
device.log.Verbosef("%v - UAPI: Removing all allowedips", peer.Peer) device.log.Verbosef("%v - UAPI: Removing all allowedips", peer.Peer)
if value != "true" { if value != "true" {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to replace allowedips, invalid value: %v", value) return ipcErrorf(
ipc.IpcErrorInvalid,
"failed to replace allowedips, invalid value: %v",
value,
)
} }
if peer.dummy { if peer.dummy {
return nil return nil
@ -376,7 +626,11 @@ func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error
device.log.Verbosef("%v - UAPI: Adding allowedip", peer.Peer) device.log.Verbosef("%v - UAPI: Adding allowedip", peer.Peer)
prefix, err := netip.ParsePrefix(value) prefix, err := netip.ParsePrefix(value)
if err != nil { if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set allowed ip: %w", err) return ipcErrorf(
ipc.IpcErrorInvalid,
"failed to set allowed ip: %w",
err,
)
} }
if peer.dummy { if peer.dummy {
return nil return nil
@ -385,7 +639,11 @@ func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error
case "protocol_version": case "protocol_version":
if value != "1" { if value != "1" {
return ipcErrorf(ipc.IpcErrorInvalid, "invalid protocol version: %v", value) return ipcErrorf(
ipc.IpcErrorInvalid,
"invalid protocol version: %v",
value,
)
} }
default: default:
@ -433,7 +691,11 @@ func (device *Device) IpcHandle(socket net.Conn) {
return return
} }
if nextByte != '\n' { if nextByte != '\n' {
err = ipcErrorf(ipc.IpcErrorInvalid, "trailing character in UAPI get: %q", nextByte) err = ipcErrorf(
ipc.IpcErrorInvalid,
"trailing character in UAPI get: %q",
nextByte,
)
break break
} }
err = device.IpcGetOperation(buffered.Writer) err = device.IpcGetOperation(buffered.Writer)

View file

@ -4,8 +4,6 @@ import (
"bytes" "bytes"
"fmt" "fmt"
"testing" "testing"
"golang.zx2c4.com/wireguard/cfg"
) )
func Test_randomJunktWithSize(t *testing.T) { func Test_randomJunktWithSize(t *testing.T) {
@ -19,7 +17,7 @@ func Test_appendJunk(t *testing.T) {
buffer := bytes.NewBuffer([]byte(s)) buffer := bytes.NewBuffer([]byte(s))
err := appendJunk(buffer, 30) err := appendJunk(buffer, 30)
if err != nil && if err != nil &&
buffer.Len() != len(s)+int(cfg.InitPacketJunkSize) { buffer.Len() != len(s)+30 {
t.Errorf("appendWithJunk() size don't match") t.Errorf("appendWithJunk() size don't match")
} }
read := make([]byte, 50) read := make([]byte, 50)