From 6c2c04ebd46febe7b0a941ff6f6bb5e7831ccec0 Mon Sep 17 00:00:00 2001 From: Mark Puha Date: Tue, 5 Sep 2023 07:03:36 +0200 Subject: [PATCH] refactor cfg & integrate Signed-off-by: Mark Puha --- cfg/cfg_def.go | 75 ------ cfg/settings/advanced_security_off.yml | 11 - cfg/settings/default.yml | 11 - device/device.go | 95 +++++++- device/device_test.go | 203 +++++++++++++--- device/noise-protocol.go | 25 +- device/receive.go | 9 +- device/send.go | 21 +- device/uapi.go | 316 ++++++++++++++++++++++--- device/util_test.go | 4 +- 10 files changed, 573 insertions(+), 197 deletions(-) delete mode 100644 cfg/cfg_def.go delete mode 100644 cfg/settings/advanced_security_off.yml delete mode 100644 cfg/settings/default.yml diff --git a/cfg/cfg_def.go b/cfg/cfg_def.go deleted file mode 100644 index f74fb90..0000000 --- a/cfg/cfg_def.go +++ /dev/null @@ -1,75 +0,0 @@ -package cfg - -import "log" - -func init() { - if IsAdvancedSecurityOn() { - if JunkPacketCount < 0 { - log.Fatalf("JunkPacketCount should be non negative") - } - if JunkPacketMaxSize <= JunkPacketMinSize { - log.Fatalf( - "MaxSize: %d; should be greater than MinSize: %d", - JunkPacketMaxSize, - JunkPacketMinSize, - ) - } - const MaxSegmentSize = 2048 - 32 - if JunkPacketMaxSize >= MaxSegmentSize { - log.Fatalf( - "JunkPacketMaxSize: %d; should be smaller than maxSegmentSize: %d", - JunkPacketMaxSize, - MaxSegmentSize, - ) - } - if 148+InitPacketJunkSize >= MaxSegmentSize { - log.Fatalf( - "init header size(148) + junkSize:%d; should be smaller than maxSegmentSize: %d", - InitPacketJunkSize, - MaxSegmentSize, - ) - } - if 92+ResponsePacketJunkSize >= MaxSegmentSize { - log.Fatalf( - "response header size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d", - ResponsePacketJunkSize, - MaxSegmentSize, - ) - } - if 64+UnderLoadPacketJunkSize >= MaxSegmentSize { - log.Fatalf( - "underload packet size(64) + junkSize:%d; should be smaller than maxSegmentSize: %d", - UnderLoadPacketJunkSize, - MaxSegmentSize, - ) - } - if 32+TransportPacketJunkSize >= MaxSegmentSize { - log.Fatalf( - "transport packet size(32) + junkSize:%d should be smaller than maxSegmentSize: %d", - TransportPacketJunkSize, - MaxSegmentSize, - ) - } - if UnderLoadPacketJunkSize != 0 || TransportPacketJunkSize != 0 { - log.Fatal( - `UnderLoadPacketJunkSize and TransportPacketJunkSize; - are currently unimplemented and should be left 0`, - ) - } - } else { - if InitPacketJunkSize != 0 || - ResponsePacketJunkSize != 0 || - UnderLoadPacketJunkSize != 0 || - TransportPacketJunkSize != 0 { - - log.Fatal("JunkSizes should be zero when advanced security on") - } - } -} - -func IsAdvancedSecurityOn() bool { - return InitPacketMagicHeader != 1 || - ResponsePacketMagicHeader != 2 || - UnderloadPacketMagicHeader != 3 || - TransportPacketMagicHeader != 4 -} diff --git a/cfg/settings/advanced_security_off.yml b/cfg/settings/advanced_security_off.yml deleted file mode 100644 index 83c3161..0000000 --- a/cfg/settings/advanced_security_off.yml +++ /dev/null @@ -1,11 +0,0 @@ -junk_packet_count: 5 -junk_packet_min_size: 10 -junk_packet_max_size: 30 -init_packet_junk_size: 0 -response_packet_junk_size: 0 -underload_packet_junk_size: 0 -transport_packet_junk_size: 0 -init_packet_magic_header: 1 -response_packet_magic_header : 2 -underload_packet_magic_header : 3 -transport_packet_magic_header : 4 \ No newline at end of file diff --git a/cfg/settings/default.yml b/cfg/settings/default.yml deleted file mode 100644 index 2728914..0000000 --- a/cfg/settings/default.yml +++ /dev/null @@ -1,11 +0,0 @@ -junk_packet_count: 5 -junk_packet_min_size: 10 -junk_packet_max_size: 30 -init_packet_junk_size: 30 -response_packet_junk_size: 50 -underload_packet_junk_size: 0 -transport_packet_junk_size: 0 -init_packet_magic_header: 1234567 -response_packet_magic_header : 7654321 -underload_packet_magic_header : 12345687 -transport_packet_magic_header : 146810 \ No newline at end of file diff --git a/device/device.go b/device/device.go index 1af9fe0..7405c1c 100644 --- a/device/device.go +++ b/device/device.go @@ -6,6 +6,7 @@ package device import ( + "log" "runtime" "sync" "sync/atomic" @@ -89,6 +90,18 @@ type Device struct { ipcMutex sync.RWMutex closed chan struct{} log *Logger + aSecCfg struct { + isOn bool + junkPacketCount int + junkPacketMinSize int + junkPacketMaxSize int + initPacketJunkSize int + responsePacketJunkSize int + initPacketMagicHeader uint32 + responsePacketMagicHeader uint32 + underloadPacketMagicHeader uint32 + transportPacketMagicHeader uint32 + } } // deviceState represents the state of a Device. @@ -142,7 +155,10 @@ func (device *Device) changeState(want deviceState) (err error) { old := device.deviceState() if old == deviceStateClosed { // once closed, always closed - device.log.Verbosef("Interface closed, ignored requested state %s", want) + device.log.Verbosef( + "Interface closed, ignored requested state %s", + want, + ) return nil } switch want { @@ -162,7 +178,12 @@ func (device *Device) changeState(want deviceState) (err error) { err = errDown } } - device.log.Verbosef("Interface state was %s, requested %s, now %s", old, want, device.deviceState()) + device.log.Verbosef( + "Interface state was %s, requested %s, now %s", + old, + want, + device.deviceState(), + ) return } @@ -267,7 +288,9 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error { expiredPeers := make([]*Peer, 0, len(device.peers.keyMap)) for _, peer := range device.peers.keyMap { handshake := &peer.handshake - handshake.precomputedStaticStatic, _ = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic) + handshake.precomputedStaticStatic, _ = device.staticIdentity.privateKey.sharedSecret( + handshake.remoteStatic, + ) expiredPeers = append(expiredPeers, peer) } @@ -411,7 +434,9 @@ func (device *Device) SendKeepalivesToPeersWithCurrentKeypair() { device.peers.RLock() for _, peer := range device.peers.keyMap { peer.keypairs.RLock() - sendKeepalive := peer.keypairs.current != nil && !peer.keypairs.current.created.Add(RejectAfterTime).Before(time.Now()) + sendKeepalive := peer.keypairs.current != nil && + !peer.keypairs.current.created.Add(RejectAfterTime). + Before(time.Now()) peer.keypairs.RUnlock() if sendKeepalive { peer.SendKeepalive() @@ -525,8 +550,12 @@ func (device *Device) BindUpdate() error { // start receiving routines device.net.stopping.Add(len(recvFns)) - device.queue.decryption.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption - device.queue.handshake.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.handshake + device.queue.decryption.wg.Add( + len(recvFns), + ) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption + device.queue.handshake.wg.Add( + len(recvFns), + ) // each RoutineReceiveIncoming goroutine writes to device.queue.handshake batchSize := netc.bind.BatchSize() for _, fn := range recvFns { go device.RoutineReceiveIncoming(batchSize, fn) @@ -542,3 +571,57 @@ func (device *Device) BindClose() error { device.net.Unlock() return err } +func (device *Device) isAdvancedSecurityOn() bool { + return device.aSecCfg.isOn +} + +func (device *Device) handlePostConfig() { + if device.isAdvancedSecurityOn() { + if device.aSecCfg.junkPacketMaxSize >= 0 { + if device.aSecCfg.junkPacketMaxSize == device.aSecCfg.junkPacketMinSize { + device.aSecCfg.junkPacketMaxSize++ // to make rand gen work + } else if device.aSecCfg.junkPacketMaxSize < device.aSecCfg.junkPacketMinSize { + log.Fatalf( + "MaxSize: %d; should be greater than MinSize: %d", + device.aSecCfg.junkPacketMaxSize, + device.aSecCfg.junkPacketMinSize, + ) + } + } + + if device.aSecCfg.initPacketMagicHeader != 0 && + device.aSecCfg.initPacketMagicHeader != 1 { + + MessageInitiationType = device.aSecCfg.initPacketMagicHeader + } + if device.aSecCfg.responsePacketMagicHeader != 0 && + device.aSecCfg.responsePacketMagicHeader != 1 { + + MessageResponseType = device.aSecCfg.responsePacketMagicHeader + } + if device.aSecCfg.underloadPacketMagicHeader != 0 && + device.aSecCfg.underloadPacketMagicHeader != 1 { + + MessageCookieReplyType = device.aSecCfg.underloadPacketMagicHeader + } + if device.aSecCfg.transportPacketMagicHeader != 0 && + device.aSecCfg.transportPacketMagicHeader != 1 { + + MessageTransportType = device.aSecCfg.transportPacketMagicHeader + } + + packetSizeToMsgType = map[int]uint32{ + MessageInitiationSize + device.aSecCfg.initPacketJunkSize: MessageInitiationType, + MessageResponseSize + device.aSecCfg.responsePacketJunkSize: MessageResponseType, + MessageCookieReplySize: MessageCookieReplyType, + MessageTransportSize: MessageTransportType, + } + + msgTypeToJunkSize = map[uint32]int{ + MessageInitiationType: device.aSecCfg.initPacketJunkSize, + MessageResponseType: device.aSecCfg.responsePacketJunkSize, + MessageCookieReplyType: 0, + MessageTransportType: 0, + } + } +} diff --git a/device/device_test.go b/device/device_test.go index fff172b..6a4dc7d 100644 --- a/device/device_test.go +++ b/device/device_test.go @@ -91,6 +91,65 @@ func genConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) { return } +func genASecurityConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) { + var key1, key2 NoisePrivateKey + _, err := rand.Read(key1[:]) + if err != nil { + tb.Errorf("unable to generate private key random bytes: %v", err) + } + _, err = rand.Read(key2[:]) + if err != nil { + tb.Errorf("unable to generate private key random bytes: %v", err) + } + pub1, pub2 := key1.publicKey(), key2.publicKey() + + cfgs[0] = uapiCfg( + "private_key", hex.EncodeToString(key1[:]), + "listen_port", "0", + "replace_peers", "true", + "jc", "5", + "jmin", "500", + "jmax", "501", + "s1", "30", + "s2", "40", + "h1", "123456", + "h2", "67543", + "h4", "32345", + "h3", "123123", + "public_key", hex.EncodeToString(pub2[:]), + "protocol_version", "1", + "replace_allowed_ips", "true", + "allowed_ip", "1.0.0.2/32", + ) + endpointCfgs[0] = uapiCfg( + "public_key", hex.EncodeToString(pub2[:]), + "endpoint", "127.0.0.1:%d", + ) + cfgs[1] = uapiCfg( + "private_key", hex.EncodeToString(key2[:]), + "listen_port", "0", + "replace_peers", "true", + "jc", "5", + "jmin", "500", + "jmax", "501", + "s1", "30", + "s2", "40", + "h1", "123456", + "h2", "67543", + "h4", "32345", + "h3", "123123", + "public_key", hex.EncodeToString(pub1[:]), + "protocol_version", "1", + "replace_allowed_ips", "true", + "allowed_ip", "1.0.0.1/32", + ) + endpointCfgs[1] = uapiCfg( + "public_key", hex.EncodeToString(pub1[:]), + "endpoint", "127.0.0.1:%d", + ) + return +} + // A testPair is a pair of testPeers. type testPair [2]testPeer @@ -115,7 +174,11 @@ func (d SendDirection) String() string { return "pong" } -func (pair *testPair) Send(tb testing.TB, ping SendDirection, done chan struct{}) { +func (pair *testPair) Send( + tb testing.TB, + ping SendDirection, + done chan struct{}, +) { tb.Helper() p0, p1 := pair[0], pair[1] if !ping { @@ -149,8 +212,16 @@ func (pair *testPair) Send(tb testing.TB, ping SendDirection, done chan struct{} } // genTestPair creates a testPair. -func genTestPair(tb testing.TB, realSocket bool) (pair testPair) { - cfg, endpointCfg := genConfigs(tb) +func genTestPair( + tb testing.TB, + realSocket, withASecurity bool, +) (pair testPair) { + var cfg, endpointCfg [2]string + if withASecurity { + cfg, endpointCfg = genASecurityConfigs(tb) + } else { + cfg, endpointCfg = genConfigs(tb) + } var binds [2]conn.Bind if realSocket { binds[0], binds[1] = conn.NewDefaultBind(), conn.NewDefaultBind() @@ -166,7 +237,11 @@ func genTestPair(tb testing.TB, realSocket bool) (pair testPair) { if _, ok := tb.(*testing.B); ok && !testing.Verbose() { level = LogLevelError } - p.dev = NewDevice(p.tun.TUN(), binds[i], NewLogger(level, fmt.Sprintf("dev%d: ", i))) + p.dev = NewDevice( + p.tun.TUN(), + binds[i], + NewLogger(level, fmt.Sprintf("dev%d: ", i)), + ) if err := p.dev.IpcSet(cfg[i]); err != nil { tb.Errorf("failed to configure device %d: %v", i, err) p.dev.Close() @@ -194,7 +269,18 @@ func genTestPair(tb testing.TB, realSocket bool) (pair testPair) { func TestTwoDevicePing(t *testing.T) { goroutineLeakCheck(t) - pair := genTestPair(t, true) + pair := genTestPair(t, true, false) + t.Run("ping 1.0.0.1", func(t *testing.T) { + pair.Send(t, Ping, nil) + }) + t.Run("ping 1.0.0.2", func(t *testing.T) { + pair.Send(t, Pong, nil) + }) +} + +func TestTwoDevicePingASecurity(t *testing.T) { + goroutineLeakCheck(t) + pair := genTestPair(t, true, true) t.Run("ping 1.0.0.1", func(t *testing.T) { pair.Send(t, Ping, nil) }) @@ -209,10 +295,15 @@ func TestUpDown(t *testing.T) { const otrials = 10 for n := 0; n < otrials; n++ { - pair := genTestPair(t, false) + pair := genTestPair(t, false, false) for i := range pair { for k := range pair[i].dev.peers.keyMap { - pair[i].dev.IpcSet(fmt.Sprintf("public_key=%s\npersistent_keepalive_interval=1\n", hex.EncodeToString(k[:]))) + pair[i].dev.IpcSet( + fmt.Sprintf( + "public_key=%s\npersistent_keepalive_interval=1\n", + hex.EncodeToString(k[:]), + ), + ) } } var wg sync.WaitGroup @@ -224,11 +315,19 @@ func TestUpDown(t *testing.T) { if err := d.Up(); err != nil { t.Errorf("failed up bring up device: %v", err) } - time.Sleep(time.Duration(rand.Intn(int(time.Nanosecond * (0x10000 - 1))))) + time.Sleep( + time.Duration( + rand.Intn(int(time.Nanosecond * (0x10000 - 1))), + ), + ) if err := d.Down(); err != nil { t.Errorf("failed to bring down device: %v", err) } - time.Sleep(time.Duration(rand.Intn(int(time.Nanosecond * (0x10000 - 1))))) + time.Sleep( + time.Duration( + rand.Intn(int(time.Nanosecond * (0x10000 - 1))), + ), + ) } }(pair[i].dev) } @@ -243,7 +342,7 @@ func TestUpDown(t *testing.T) { // TestConcurrencySafety does other things concurrently with tunnel use. // It is intended to be used with the race detector to catch data races. func TestConcurrencySafety(t *testing.T) { - pair := genTestPair(t, true) + pair := genTestPair(t, true, false) done := make(chan struct{}) const warmupIters = 10 @@ -294,8 +393,14 @@ func TestConcurrencySafety(t *testing.T) { // Change private keys concurrently with tunnel use. t.Run("privateKey", func(t *testing.T) { - bad := uapiCfg("private_key", "7777777777777777777777777777777777777777777777777777777777777777") - good := uapiCfg("private_key", hex.EncodeToString(pair[0].dev.staticIdentity.privateKey[:])) + bad := uapiCfg( + "private_key", + "7777777777777777777777777777777777777777777777777777777777777777", + ) + good := uapiCfg( + "private_key", + hex.EncodeToString(pair[0].dev.staticIdentity.privateKey[:]), + ) // Set iters to a large number like 1000 to flush out data races quickly. // Don't leave it large. That can cause logical races // in which the handshake is interleaved with key changes @@ -324,7 +429,7 @@ func TestConcurrencySafety(t *testing.T) { } func BenchmarkLatency(b *testing.B) { - pair := genTestPair(b, true) + pair := genTestPair(b, true, false) // Establish a connection. pair.Send(b, Ping, nil) @@ -338,7 +443,7 @@ func BenchmarkLatency(b *testing.B) { } func BenchmarkThroughput(b *testing.B) { - pair := genTestPair(b, true) + pair := genTestPair(b, true, false) // Establish a connection. pair.Send(b, Ping, nil) @@ -382,7 +487,7 @@ func BenchmarkThroughput(b *testing.B) { } func BenchmarkUAPIGet(b *testing.B) { - pair := genTestPair(b, true) + pair := genTestPair(b, true, false) pair.Send(b, Ping, nil) pair.Send(b, Pong, nil) b.ReportAllocs() @@ -415,7 +520,11 @@ func goroutineLeakCheck(t *testing.T) { endGoroutines, endStacks := goroutines() t.Logf("starting stacks:\n%s\n", startStacks) t.Logf("ending stacks:\n%s\n", endStacks) - t.Fatalf("expected %d goroutines, got %d, leak?", startGoroutines, endGoroutines) + t.Fatalf( + "expected %d goroutines, got %d, leak?", + startGoroutines, + endGoroutines, + ) }) } @@ -423,29 +532,65 @@ type fakeBindSized struct { size int } -func (b *fakeBindSized) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) { +func (b *fakeBindSized) Open( + port uint16, +) (fns []conn.ReceiveFunc, actualPort uint16, err error) { return nil, 0, nil } -func (b *fakeBindSized) Close() error { return nil } -func (b *fakeBindSized) SetMark(mark uint32) error { return nil } -func (b *fakeBindSized) Send(bufs [][]byte, ep conn.Endpoint) error { return nil } -func (b *fakeBindSized) ParseEndpoint(s string) (conn.Endpoint, error) { return nil, nil } -func (b *fakeBindSized) BatchSize() int { return b.size } + +func (b *fakeBindSized) Close() error { return nil } + +func (b *fakeBindSized) SetMark( + mark uint32, +) error { + return nil +} + +func (b *fakeBindSized) Send( + bufs [][]byte, + ep conn.Endpoint, +) error { + return nil +} + +func (b *fakeBindSized) ParseEndpoint( + s string, +) (conn.Endpoint, error) { + return nil, nil +} + +func (b *fakeBindSized) BatchSize() int { return b.size } type fakeTUNDeviceSized struct { size int } func (t *fakeTUNDeviceSized) File() *os.File { return nil } -func (t *fakeTUNDeviceSized) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) { + +func (t *fakeTUNDeviceSized) Read( + bufs [][]byte, + sizes []int, + offset int, +) (n int, err error) { return 0, nil } -func (t *fakeTUNDeviceSized) Write(bufs [][]byte, offset int) (int, error) { return 0, nil } -func (t *fakeTUNDeviceSized) MTU() (int, error) { return 0, nil } -func (t *fakeTUNDeviceSized) Name() (string, error) { return "", nil } -func (t *fakeTUNDeviceSized) Events() <-chan tun.Event { return nil } -func (t *fakeTUNDeviceSized) Close() error { return nil } -func (t *fakeTUNDeviceSized) BatchSize() int { return t.size } + +func (t *fakeTUNDeviceSized) Write( + bufs [][]byte, + offset int, +) (int, error) { + return 0, nil +} + +func (t *fakeTUNDeviceSized) MTU() (int, error) { return 0, nil } + +func (t *fakeTUNDeviceSized) Name() (string, error) { return "", nil } + +func (t *fakeTUNDeviceSized) Events() <-chan tun.Event { return nil } + +func (t *fakeTUNDeviceSized) Close() error { return nil } + +func (t *fakeTUNDeviceSized) BatchSize() int { return t.size } func TestBatchSize(t *testing.T) { d := Device{} diff --git a/device/noise-protocol.go b/device/noise-protocol.go index 44bc1ac..e7ad927 100644 --- a/device/noise-protocol.go +++ b/device/noise-protocol.go @@ -15,7 +15,6 @@ import ( "golang.org/x/crypto/chacha20poly1305" "golang.org/x/crypto/poly1305" - "golang.zx2c4.com/wireguard/cfg" "golang.zx2c4.com/wireguard/tai64n" ) @@ -53,11 +52,11 @@ const ( WGLabelCookie = "cookie--" ) -const ( - MessageInitiationType = cfg.InitPacketMagicHeader - MessageResponseType = cfg.ResponsePacketMagicHeader - MessageCookieReplyType = cfg.UnderloadPacketMagicHeader - MessageTransportType = cfg.TransportPacketMagicHeader +var ( + MessageInitiationType uint32 = 1 + MessageResponseType uint32 = 2 + MessageCookieReplyType uint32 = 3 + MessageTransportType uint32 = 4 ) const ( @@ -76,19 +75,9 @@ const ( MessageTransportOffsetContent = 16 ) -var packetSizeToMsgType = map[int]uint32{ - MessageInitiationSize + cfg.InitPacketJunkSize: MessageInitiationType, - MessageResponseSize + cfg.ResponsePacketJunkSize: MessageResponseType, - MessageCookieReplySize + cfg.UnderLoadPacketJunkSize: MessageCookieReplyType, - MessageTransportSize + cfg.TransportPacketJunkSize: MessageTransportType, -} +var packetSizeToMsgType map[int]uint32 -var msgTypeToJunkSize = map[uint32]int{ - MessageInitiationType: cfg.InitPacketJunkSize, - MessageResponseType: cfg.ResponsePacketJunkSize, - MessageCookieReplyType: cfg.UnderLoadPacketJunkSize, - MessageTransportType: cfg.TransportPacketJunkSize, -} +var msgTypeToJunkSize map[uint32]int /* Type is an 8-bit field, followed by 3 nul bytes, * by marshalling the messages in little-endian byteorder diff --git a/device/receive.go b/device/receive.go index 9112ed0..4714175 100644 --- a/device/receive.go +++ b/device/receive.go @@ -16,7 +16,6 @@ import ( "golang.org/x/crypto/chacha20poly1305" "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" - "golang.zx2c4.com/wireguard/cfg" "golang.zx2c4.com/wireguard/conn" ) @@ -142,12 +141,10 @@ func (device *Device) RoutineReceiveIncoming( // check size of packet packet := bufsArrs[i][:size] - if cfg.IsAdvancedSecurityOn() { + if device.isAdvancedSecurityOn() { var junkSize int - if mapMsgType, ok := packetSizeToMsgType[size]; ok { - junkSize = msgTypeToJunkSize[mapMsgType] - } else { - junkSize = cfg.TransportPacketJunkSize + if msgType, ok := packetSizeToMsgType[size]; ok { + junkSize = msgTypeToJunkSize[msgType] } // shift junk packet = packet[junkSize:] diff --git a/device/send.go b/device/send.go index 9cbb7ab..8691f76 100644 --- a/device/send.go +++ b/device/send.go @@ -19,7 +19,6 @@ import ( "golang.org/x/crypto/chacha20poly1305" "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" - "golang.zx2c4.com/wireguard/cfg" "golang.zx2c4.com/wireguard/tun" ) @@ -128,15 +127,15 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { } // so only packet processed for cookie generation var junkedHeader []byte - if cfg.IsAdvancedSecurityOn() { + if peer.device.isAdvancedSecurityOn() { err = peer.sendJunkPackets() if err != nil { peer.device.log.Errorf("%v - %v", peer, err) return err } - var buf [cfg.InitPacketJunkSize]byte + buf := make([]byte, 0, peer.device.aSecCfg.initPacketJunkSize) writer := bytes.NewBuffer(buf[:0]) - err = appendJunk(writer, cfg.InitPacketJunkSize) + err = appendJunk(writer, peer.device.aSecCfg.initPacketJunkSize) if err != nil { peer.device.log.Errorf("%v - %v", peer, err) return err @@ -183,10 +182,10 @@ func (peer *Peer) SendHandshakeResponse() error { return err } var junkedHeader []byte - if cfg.IsAdvancedSecurityOn() { - var buf [cfg.ResponsePacketJunkSize]byte + if peer.device.isAdvancedSecurityOn() { + buf := make([]byte, 0, peer.device.aSecCfg.responsePacketJunkSize) writer := bytes.NewBuffer(buf[:0]) - err = appendJunk(writer, cfg.ResponsePacketJunkSize) + err = appendJunk(writer, peer.device.aSecCfg.responsePacketJunkSize) if err != nil { peer.device.log.Errorf("%v - %v", peer, err) return err @@ -472,11 +471,11 @@ top: } func (peer *Peer) sendJunkPackets() error { - junks := make([][]byte, 0, cfg.JunkPacketCount) - for i := 0; i < cfg.JunkPacketCount; i++ { + junks := make([][]byte, 0, peer.device.aSecCfg.junkPacketCount) + for i := 0; i < peer.device.aSecCfg.junkPacketCount; i++ { packetSize := rand.Intn( - cfg.JunkPacketMaxSize-cfg.JunkPacketMinSize, - ) + cfg.JunkPacketMinSize + peer.device.aSecCfg.junkPacketMaxSize-peer.device.aSecCfg.junkPacketMinSize, + ) + peer.device.aSecCfg.junkPacketMinSize junk, err := randomJunkWithSize(packetSize) if err != nil { diff --git a/device/uapi.go b/device/uapi.go index 617dcd3..acec80a 100644 --- a/device/uapi.go +++ b/device/uapi.go @@ -11,6 +11,7 @@ import ( "errors" "fmt" "io" + "log" "net" "net/netip" "strconv" @@ -97,6 +98,36 @@ func (device *Device) IpcGetOperation(w io.Writer) error { sendf("fwmark=%d", device.net.fwmark) } + if device.isAdvancedSecurityOn() { + if device.aSecCfg.junkPacketCount != 0 { + sendf("jc=%d", device.aSecCfg.junkPacketCount) + } + if device.aSecCfg.junkPacketMinSize != 0 { + sendf("jmin=%d", device.aSecCfg.junkPacketMinSize) + } + if device.aSecCfg.junkPacketMaxSize != 0 { + sendf("jmax=%d", device.aSecCfg.junkPacketMaxSize) + } + if device.aSecCfg.initPacketJunkSize != 0 { + sendf("s1=%d", device.aSecCfg.initPacketJunkSize) + } + if device.aSecCfg.responsePacketJunkSize != 0 { + sendf("s2=%d", device.aSecCfg.responsePacketJunkSize) + } + if device.aSecCfg.initPacketMagicHeader != 0 { + sendf("h1=%d", device.aSecCfg.initPacketMagicHeader) + } + if device.aSecCfg.responsePacketMagicHeader != 0 { + sendf("h2=%d", device.aSecCfg.responsePacketMagicHeader) + } + if device.aSecCfg.underloadPacketMagicHeader != 0 { + sendf("h3=%d", device.aSecCfg.underloadPacketMagicHeader) + } + if device.aSecCfg.transportPacketMagicHeader != 0 { + sendf("h4=%d", device.aSecCfg.transportPacketMagicHeader) + } + } + for _, peer := range device.peers.keyMap { // Serialize peer state. // Do the work in an anonymous function so that we can use defer. @@ -119,12 +150,18 @@ func (device *Device) IpcGetOperation(w io.Writer) error { sendf("last_handshake_time_nsec=%d", nano) sendf("tx_bytes=%d", peer.txBytes.Load()) sendf("rx_bytes=%d", peer.rxBytes.Load()) - sendf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval.Load()) + sendf( + "persistent_keepalive_interval=%d", + peer.persistentKeepaliveInterval.Load(), + ) - device.allowedips.EntriesForPeer(peer, func(prefix netip.Prefix) bool { - sendf("allowed_ip=%s", prefix.String()) - return true - }) + device.allowedips.EntriesForPeer( + peer, + func(prefix netip.Prefix) bool { + sendf("allowed_ip=%s", prefix.String()) + return true + }, + ) }() } }() @@ -162,7 +199,11 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) { } key, value, ok := strings.Cut(line, "=") if !ok { - return ipcErrorf(ipc.IpcErrorProtocol, "failed to parse line %q", line) + return ipcErrorf( + ipc.IpcErrorProtocol, + "failed to parse line %q", + line, + ) } if key == "public_key" { @@ -188,6 +229,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) { return err } } + device.handlePostConfig() peer.handlePostConfig() if err := scanner.Err(); err != nil { @@ -202,7 +244,11 @@ func (device *Device) handleDeviceLine(key, value string) error { var sk NoisePrivateKey err := sk.FromMaybeZeroHex(value) if err != nil { - return ipcErrorf(ipc.IpcErrorInvalid, "failed to set private_key: %w", err) + return ipcErrorf( + ipc.IpcErrorInvalid, + "failed to set private_key: %w", + err, + ) } device.log.Verbosef("UAPI: Updating private key") device.SetPrivateKey(sk) @@ -210,7 +256,11 @@ func (device *Device) handleDeviceLine(key, value string) error { case "listen_port": port, err := strconv.ParseUint(value, 10, 16) if err != nil { - return ipcErrorf(ipc.IpcErrorInvalid, "failed to parse listen_port: %w", err) + return ipcErrorf( + ipc.IpcErrorInvalid, + "failed to parse listen_port: %w", + err, + ) } // update port and rebind @@ -221,7 +271,11 @@ func (device *Device) handleDeviceLine(key, value string) error { device.net.Unlock() if err := device.BindUpdate(); err != nil { - return ipcErrorf(ipc.IpcErrorPortInUse, "failed to set listen_port: %w", err) + return ipcErrorf( + ipc.IpcErrorPortInUse, + "failed to set listen_port: %w", + err, + ) } case "fwmark": @@ -232,18 +286,171 @@ func (device *Device) handleDeviceLine(key, value string) error { device.log.Verbosef("UAPI: Updating fwmark") if err := device.BindSetMark(uint32(mark)); err != nil { - return ipcErrorf(ipc.IpcErrorPortInUse, "failed to update fwmark: %w", err) + return ipcErrorf( + ipc.IpcErrorPortInUse, + "failed to update fwmark: %w", + err, + ) } case "replace_peers": if value != "true" { - return ipcErrorf(ipc.IpcErrorInvalid, "failed to set replace_peers, invalid value: %v", value) + return ipcErrorf( + ipc.IpcErrorInvalid, + "failed to set replace_peers, invalid value: %v", + value, + ) } device.log.Verbosef("UAPI: Removing all peers") device.RemoveAllPeers() + case "jc": + junkPacketCount, err := strconv.Atoi(value) + if err != nil { + return ipcErrorf( + ipc.IpcErrorInvalid, + "faield to parse junk_packet_count %w", + err, + ) + } + if junkPacketCount < 0 { + log.Fatalf("JunkPacketCount should be non negative") + } + device.log.Verbosef("UAPI: Updating junk_packet_count") + device.aSecCfg.isOn = true + device.aSecCfg.junkPacketCount = junkPacketCount + case "jmin": + junkPacketMinSize, err := strconv.Atoi(value) + if err != nil { + return ipcErrorf( + ipc.IpcErrorInvalid, + "faield to parse junk_packet_min_size %w", + err, + ) + } + device.log.Verbosef("UAPI: Updating junk_packet_min_size") + device.aSecCfg.isOn = true + device.aSecCfg.junkPacketMinSize = junkPacketMinSize + case "jmax": + junkPacketMaxSize, err := strconv.Atoi(value) + if err != nil { + return ipcErrorf( + ipc.IpcErrorInvalid, + "faield to parse junk_packet_max_size %w", + err, + ) + } + if junkPacketMaxSize >= MaxSegmentSize { + log.Fatalf( + "JunkPacketMaxSize: %d; should be smaller than maxSegmentSize: %d", + junkPacketMaxSize, + MaxSegmentSize, + ) + } + device.log.Verbosef("UAPI: Updating junk_packet_max_size") + device.aSecCfg.isOn = true + device.aSecCfg.junkPacketMaxSize = junkPacketMaxSize + case "s1": + initPacketJunkSize, err := strconv.Atoi(value) + if err != nil { + return ipcErrorf( + ipc.IpcErrorInvalid, + "faield to parse init_packet_junk_size %w", + err, + ) + } + if 148+initPacketJunkSize >= MaxSegmentSize { + log.Fatalf( + `init header size(148) + junkSize:%d; + should be smaller than maxSegmentSize: %d`, + initPacketJunkSize, + MaxSegmentSize, + ) + } + device.log.Verbosef("UAPI: Updating init_packet_junk_size") + device.aSecCfg.isOn = true + device.aSecCfg.initPacketJunkSize = initPacketJunkSize + case "s2": + responsePacketJunkSize, err := strconv.Atoi(value) + if err != nil { + return ipcErrorf( + ipc.IpcErrorInvalid, + "faield to parse response_packet_junk_size %w", + err, + ) + } + if 92+responsePacketJunkSize >= MaxSegmentSize { + log.Fatalf( + `response header size(92) + junkSize:%d; + should be smaller than maxSegmentSize: %d`, + responsePacketJunkSize, + MaxSegmentSize, + ) + } + device.log.Verbosef("UAPI: Updating response_packet_junk_size") + device.aSecCfg.isOn = true + device.aSecCfg.responsePacketJunkSize = responsePacketJunkSize + + case "h1": + initPacketMagicHeader, err := strconv.ParseUint(value, 10, 32) + if err != nil { + return ipcErrorf( + ipc.IpcErrorInvalid, + "faield to parse init_packet_magic_header %w", + err, + ) + } + device.log.Verbosef("UAPI: Updating init_packet_magic_header") + device.aSecCfg.isOn = true + device.aSecCfg.initPacketMagicHeader = uint32(initPacketMagicHeader) + case "h2": + responsePacketMagicHeader, err := strconv.ParseUint(value, 10, 32) + if err != nil { + return ipcErrorf( + ipc.IpcErrorInvalid, + "faield to parse response_packet_magic_header %w", + err, + ) + } + device.log.Verbosef("UAPI: Updating response_packet_magic_header") + device.aSecCfg.isOn = true + device.aSecCfg.responsePacketMagicHeader = uint32( + responsePacketMagicHeader, + ) + case "h3": + underloadPacketMagicHeader, err := strconv.ParseUint(value, 10, 32) + if err != nil { + return ipcErrorf( + ipc.IpcErrorInvalid, + "faield to parse underload_packet_magic_header %w", + err, + ) + } + device.log.Verbosef("UAPI: Updating underload_packet_magic_header") + device.aSecCfg.isOn = true + device.aSecCfg.underloadPacketMagicHeader = uint32( + underloadPacketMagicHeader, + ) + case "h4": + transportPacketMagicHeader, err := strconv.ParseUint(value, 10, 32) + if err != nil { + return ipcErrorf( + ipc.IpcErrorInvalid, + "faield to parse transport_packet_magic_header %w", + err, + ) + } + device.log.Verbosef("UAPI: Updating transport_packet_magic_header") + device.aSecCfg.isOn = true + device.aSecCfg.transportPacketMagicHeader = uint32( + transportPacketMagicHeader, + ) default: - return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI device key: %v", key) + return ipcErrorf( + ipc.IpcErrorInvalid, + "invalid UAPI device key: %v", + key, + ) } return nil @@ -262,7 +469,8 @@ func (peer *ipcSetPeer) handlePostConfig() { return } if peer.created { - peer.disableRoaming = peer.device.net.brokenRoaming && peer.endpoint != nil + peer.disableRoaming = peer.device.net.brokenRoaming && + peer.endpoint != nil } if peer.device.isUp() { peer.Start() @@ -273,12 +481,19 @@ func (peer *ipcSetPeer) handlePostConfig() { } } -func (device *Device) handlePublicKeyLine(peer *ipcSetPeer, value string) error { +func (device *Device) handlePublicKeyLine( + peer *ipcSetPeer, + value string, +) error { // Load/create the peer we are configuring. var publicKey NoisePublicKey err := publicKey.FromHex(value) if err != nil { - return ipcErrorf(ipc.IpcErrorInvalid, "failed to get peer by public key: %w", err) + return ipcErrorf( + ipc.IpcErrorInvalid, + "failed to get peer by public key: %w", + err, + ) } // Ignore peer with the same public key as this device. @@ -296,19 +511,30 @@ func (device *Device) handlePublicKeyLine(peer *ipcSetPeer, value string) error if peer.created { peer.Peer, err = device.NewPeer(publicKey) if err != nil { - return ipcErrorf(ipc.IpcErrorInvalid, "failed to create new peer: %w", err) + return ipcErrorf( + ipc.IpcErrorInvalid, + "failed to create new peer: %w", + err, + ) } device.log.Verbosef("%v - UAPI: Created", peer.Peer) } return nil } -func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error { +func (device *Device) handlePeerLine( + peer *ipcSetPeer, + key, value string, +) error { switch key { case "update_only": // allow disabling of creation if value != "true" { - return ipcErrorf(ipc.IpcErrorInvalid, "failed to set update only, invalid value: %v", value) + return ipcErrorf( + ipc.IpcErrorInvalid, + "failed to set update only, invalid value: %v", + value, + ) } if peer.created && !peer.dummy { device.RemovePeer(peer.handshake.remoteStatic) @@ -319,7 +545,11 @@ func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error case "remove": // remove currently selected peer from device if value != "true" { - return ipcErrorf(ipc.IpcErrorInvalid, "failed to set remove, invalid value: %v", value) + return ipcErrorf( + ipc.IpcErrorInvalid, + "failed to set remove, invalid value: %v", + value, + ) } if !peer.dummy { device.log.Verbosef("%v - UAPI: Removing", peer.Peer) @@ -336,25 +566,41 @@ func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error peer.handshake.mutex.Unlock() if err != nil { - return ipcErrorf(ipc.IpcErrorInvalid, "failed to set preshared key: %w", err) + return ipcErrorf( + ipc.IpcErrorInvalid, + "failed to set preshared key: %w", + err, + ) } case "endpoint": device.log.Verbosef("%v - UAPI: Updating endpoint", peer.Peer) endpoint, err := device.net.bind.ParseEndpoint(value) if err != nil { - return ipcErrorf(ipc.IpcErrorInvalid, "failed to set endpoint %v: %w", value, err) + return ipcErrorf( + ipc.IpcErrorInvalid, + "failed to set endpoint %v: %w", + value, + err, + ) } peer.Lock() defer peer.Unlock() peer.endpoint = endpoint case "persistent_keepalive_interval": - device.log.Verbosef("%v - UAPI: Updating persistent keepalive interval", peer.Peer) + device.log.Verbosef( + "%v - UAPI: Updating persistent keepalive interval", + peer.Peer, + ) secs, err := strconv.ParseUint(value, 10, 16) if err != nil { - return ipcErrorf(ipc.IpcErrorInvalid, "failed to set persistent keepalive interval: %w", err) + return ipcErrorf( + ipc.IpcErrorInvalid, + "failed to set persistent keepalive interval: %w", + err, + ) } old := peer.persistentKeepaliveInterval.Swap(uint32(secs)) @@ -365,7 +611,11 @@ func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error case "replace_allowed_ips": device.log.Verbosef("%v - UAPI: Removing all allowedips", peer.Peer) if value != "true" { - return ipcErrorf(ipc.IpcErrorInvalid, "failed to replace allowedips, invalid value: %v", value) + return ipcErrorf( + ipc.IpcErrorInvalid, + "failed to replace allowedips, invalid value: %v", + value, + ) } if peer.dummy { return nil @@ -376,7 +626,11 @@ func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error device.log.Verbosef("%v - UAPI: Adding allowedip", peer.Peer) prefix, err := netip.ParsePrefix(value) if err != nil { - return ipcErrorf(ipc.IpcErrorInvalid, "failed to set allowed ip: %w", err) + return ipcErrorf( + ipc.IpcErrorInvalid, + "failed to set allowed ip: %w", + err, + ) } if peer.dummy { return nil @@ -385,7 +639,11 @@ func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error case "protocol_version": if value != "1" { - return ipcErrorf(ipc.IpcErrorInvalid, "invalid protocol version: %v", value) + return ipcErrorf( + ipc.IpcErrorInvalid, + "invalid protocol version: %v", + value, + ) } default: @@ -433,7 +691,11 @@ func (device *Device) IpcHandle(socket net.Conn) { return } if nextByte != '\n' { - err = ipcErrorf(ipc.IpcErrorInvalid, "trailing character in UAPI get: %q", nextByte) + err = ipcErrorf( + ipc.IpcErrorInvalid, + "trailing character in UAPI get: %q", + nextByte, + ) break } err = device.IpcGetOperation(buffered.Writer) diff --git a/device/util_test.go b/device/util_test.go index 0a936cf..c061eef 100644 --- a/device/util_test.go +++ b/device/util_test.go @@ -4,8 +4,6 @@ import ( "bytes" "fmt" "testing" - - "golang.zx2c4.com/wireguard/cfg" ) func Test_randomJunktWithSize(t *testing.T) { @@ -19,7 +17,7 @@ func Test_appendJunk(t *testing.T) { buffer := bytes.NewBuffer([]byte(s)) err := appendJunk(buffer, 30) if err != nil && - buffer.Len() != len(s)+int(cfg.InitPacketJunkSize) { + buffer.Len() != len(s)+30 { t.Errorf("appendWithJunk() size don't match") } read := make([]byte, 50)