mirror of
https://github.com/amnezia-vpn/amneziawg-go.git
synced 2025-06-07 22:03:44 +02:00
refactor cfg & integrate
Signed-off-by: Mark Puha <marko10@inf.elte.hu>
This commit is contained in:
parent
999cc3212d
commit
6c2c04ebd4
10 changed files with 573 additions and 197 deletions
|
@ -1,75 +0,0 @@
|
||||||
package cfg
|
|
||||||
|
|
||||||
import "log"
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
if IsAdvancedSecurityOn() {
|
|
||||||
if JunkPacketCount < 0 {
|
|
||||||
log.Fatalf("JunkPacketCount should be non negative")
|
|
||||||
}
|
|
||||||
if JunkPacketMaxSize <= JunkPacketMinSize {
|
|
||||||
log.Fatalf(
|
|
||||||
"MaxSize: %d; should be greater than MinSize: %d",
|
|
||||||
JunkPacketMaxSize,
|
|
||||||
JunkPacketMinSize,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
const MaxSegmentSize = 2048 - 32
|
|
||||||
if JunkPacketMaxSize >= MaxSegmentSize {
|
|
||||||
log.Fatalf(
|
|
||||||
"JunkPacketMaxSize: %d; should be smaller than maxSegmentSize: %d",
|
|
||||||
JunkPacketMaxSize,
|
|
||||||
MaxSegmentSize,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
if 148+InitPacketJunkSize >= MaxSegmentSize {
|
|
||||||
log.Fatalf(
|
|
||||||
"init header size(148) + junkSize:%d; should be smaller than maxSegmentSize: %d",
|
|
||||||
InitPacketJunkSize,
|
|
||||||
MaxSegmentSize,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
if 92+ResponsePacketJunkSize >= MaxSegmentSize {
|
|
||||||
log.Fatalf(
|
|
||||||
"response header size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d",
|
|
||||||
ResponsePacketJunkSize,
|
|
||||||
MaxSegmentSize,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
if 64+UnderLoadPacketJunkSize >= MaxSegmentSize {
|
|
||||||
log.Fatalf(
|
|
||||||
"underload packet size(64) + junkSize:%d; should be smaller than maxSegmentSize: %d",
|
|
||||||
UnderLoadPacketJunkSize,
|
|
||||||
MaxSegmentSize,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
if 32+TransportPacketJunkSize >= MaxSegmentSize {
|
|
||||||
log.Fatalf(
|
|
||||||
"transport packet size(32) + junkSize:%d should be smaller than maxSegmentSize: %d",
|
|
||||||
TransportPacketJunkSize,
|
|
||||||
MaxSegmentSize,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
if UnderLoadPacketJunkSize != 0 || TransportPacketJunkSize != 0 {
|
|
||||||
log.Fatal(
|
|
||||||
`UnderLoadPacketJunkSize and TransportPacketJunkSize;
|
|
||||||
are currently unimplemented and should be left 0`,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if InitPacketJunkSize != 0 ||
|
|
||||||
ResponsePacketJunkSize != 0 ||
|
|
||||||
UnderLoadPacketJunkSize != 0 ||
|
|
||||||
TransportPacketJunkSize != 0 {
|
|
||||||
|
|
||||||
log.Fatal("JunkSizes should be zero when advanced security on")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func IsAdvancedSecurityOn() bool {
|
|
||||||
return InitPacketMagicHeader != 1 ||
|
|
||||||
ResponsePacketMagicHeader != 2 ||
|
|
||||||
UnderloadPacketMagicHeader != 3 ||
|
|
||||||
TransportPacketMagicHeader != 4
|
|
||||||
}
|
|
|
@ -1,11 +0,0 @@
|
||||||
junk_packet_count: 5
|
|
||||||
junk_packet_min_size: 10
|
|
||||||
junk_packet_max_size: 30
|
|
||||||
init_packet_junk_size: 0
|
|
||||||
response_packet_junk_size: 0
|
|
||||||
underload_packet_junk_size: 0
|
|
||||||
transport_packet_junk_size: 0
|
|
||||||
init_packet_magic_header: 1
|
|
||||||
response_packet_magic_header : 2
|
|
||||||
underload_packet_magic_header : 3
|
|
||||||
transport_packet_magic_header : 4
|
|
|
@ -1,11 +0,0 @@
|
||||||
junk_packet_count: 5
|
|
||||||
junk_packet_min_size: 10
|
|
||||||
junk_packet_max_size: 30
|
|
||||||
init_packet_junk_size: 30
|
|
||||||
response_packet_junk_size: 50
|
|
||||||
underload_packet_junk_size: 0
|
|
||||||
transport_packet_junk_size: 0
|
|
||||||
init_packet_magic_header: 1234567
|
|
||||||
response_packet_magic_header : 7654321
|
|
||||||
underload_packet_magic_header : 12345687
|
|
||||||
transport_packet_magic_header : 146810
|
|
|
@ -6,6 +6,7 @@
|
||||||
package device
|
package device
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"log"
|
||||||
"runtime"
|
"runtime"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
@ -89,6 +90,18 @@ type Device struct {
|
||||||
ipcMutex sync.RWMutex
|
ipcMutex sync.RWMutex
|
||||||
closed chan struct{}
|
closed chan struct{}
|
||||||
log *Logger
|
log *Logger
|
||||||
|
aSecCfg struct {
|
||||||
|
isOn bool
|
||||||
|
junkPacketCount int
|
||||||
|
junkPacketMinSize int
|
||||||
|
junkPacketMaxSize int
|
||||||
|
initPacketJunkSize int
|
||||||
|
responsePacketJunkSize int
|
||||||
|
initPacketMagicHeader uint32
|
||||||
|
responsePacketMagicHeader uint32
|
||||||
|
underloadPacketMagicHeader uint32
|
||||||
|
transportPacketMagicHeader uint32
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// deviceState represents the state of a Device.
|
// deviceState represents the state of a Device.
|
||||||
|
@ -142,7 +155,10 @@ func (device *Device) changeState(want deviceState) (err error) {
|
||||||
old := device.deviceState()
|
old := device.deviceState()
|
||||||
if old == deviceStateClosed {
|
if old == deviceStateClosed {
|
||||||
// once closed, always closed
|
// once closed, always closed
|
||||||
device.log.Verbosef("Interface closed, ignored requested state %s", want)
|
device.log.Verbosef(
|
||||||
|
"Interface closed, ignored requested state %s",
|
||||||
|
want,
|
||||||
|
)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
switch want {
|
switch want {
|
||||||
|
@ -162,7 +178,12 @@ func (device *Device) changeState(want deviceState) (err error) {
|
||||||
err = errDown
|
err = errDown
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
device.log.Verbosef("Interface state was %s, requested %s, now %s", old, want, device.deviceState())
|
device.log.Verbosef(
|
||||||
|
"Interface state was %s, requested %s, now %s",
|
||||||
|
old,
|
||||||
|
want,
|
||||||
|
device.deviceState(),
|
||||||
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -267,7 +288,9 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
|
||||||
expiredPeers := make([]*Peer, 0, len(device.peers.keyMap))
|
expiredPeers := make([]*Peer, 0, len(device.peers.keyMap))
|
||||||
for _, peer := range device.peers.keyMap {
|
for _, peer := range device.peers.keyMap {
|
||||||
handshake := &peer.handshake
|
handshake := &peer.handshake
|
||||||
handshake.precomputedStaticStatic, _ = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic)
|
handshake.precomputedStaticStatic, _ = device.staticIdentity.privateKey.sharedSecret(
|
||||||
|
handshake.remoteStatic,
|
||||||
|
)
|
||||||
expiredPeers = append(expiredPeers, peer)
|
expiredPeers = append(expiredPeers, peer)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -411,7 +434,9 @@ func (device *Device) SendKeepalivesToPeersWithCurrentKeypair() {
|
||||||
device.peers.RLock()
|
device.peers.RLock()
|
||||||
for _, peer := range device.peers.keyMap {
|
for _, peer := range device.peers.keyMap {
|
||||||
peer.keypairs.RLock()
|
peer.keypairs.RLock()
|
||||||
sendKeepalive := peer.keypairs.current != nil && !peer.keypairs.current.created.Add(RejectAfterTime).Before(time.Now())
|
sendKeepalive := peer.keypairs.current != nil &&
|
||||||
|
!peer.keypairs.current.created.Add(RejectAfterTime).
|
||||||
|
Before(time.Now())
|
||||||
peer.keypairs.RUnlock()
|
peer.keypairs.RUnlock()
|
||||||
if sendKeepalive {
|
if sendKeepalive {
|
||||||
peer.SendKeepalive()
|
peer.SendKeepalive()
|
||||||
|
@ -525,8 +550,12 @@ func (device *Device) BindUpdate() error {
|
||||||
|
|
||||||
// start receiving routines
|
// start receiving routines
|
||||||
device.net.stopping.Add(len(recvFns))
|
device.net.stopping.Add(len(recvFns))
|
||||||
device.queue.decryption.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption
|
device.queue.decryption.wg.Add(
|
||||||
device.queue.handshake.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.handshake
|
len(recvFns),
|
||||||
|
) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption
|
||||||
|
device.queue.handshake.wg.Add(
|
||||||
|
len(recvFns),
|
||||||
|
) // each RoutineReceiveIncoming goroutine writes to device.queue.handshake
|
||||||
batchSize := netc.bind.BatchSize()
|
batchSize := netc.bind.BatchSize()
|
||||||
for _, fn := range recvFns {
|
for _, fn := range recvFns {
|
||||||
go device.RoutineReceiveIncoming(batchSize, fn)
|
go device.RoutineReceiveIncoming(batchSize, fn)
|
||||||
|
@ -542,3 +571,57 @@ func (device *Device) BindClose() error {
|
||||||
device.net.Unlock()
|
device.net.Unlock()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
func (device *Device) isAdvancedSecurityOn() bool {
|
||||||
|
return device.aSecCfg.isOn
|
||||||
|
}
|
||||||
|
|
||||||
|
func (device *Device) handlePostConfig() {
|
||||||
|
if device.isAdvancedSecurityOn() {
|
||||||
|
if device.aSecCfg.junkPacketMaxSize >= 0 {
|
||||||
|
if device.aSecCfg.junkPacketMaxSize == device.aSecCfg.junkPacketMinSize {
|
||||||
|
device.aSecCfg.junkPacketMaxSize++ // to make rand gen work
|
||||||
|
} else if device.aSecCfg.junkPacketMaxSize < device.aSecCfg.junkPacketMinSize {
|
||||||
|
log.Fatalf(
|
||||||
|
"MaxSize: %d; should be greater than MinSize: %d",
|
||||||
|
device.aSecCfg.junkPacketMaxSize,
|
||||||
|
device.aSecCfg.junkPacketMinSize,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if device.aSecCfg.initPacketMagicHeader != 0 &&
|
||||||
|
device.aSecCfg.initPacketMagicHeader != 1 {
|
||||||
|
|
||||||
|
MessageInitiationType = device.aSecCfg.initPacketMagicHeader
|
||||||
|
}
|
||||||
|
if device.aSecCfg.responsePacketMagicHeader != 0 &&
|
||||||
|
device.aSecCfg.responsePacketMagicHeader != 1 {
|
||||||
|
|
||||||
|
MessageResponseType = device.aSecCfg.responsePacketMagicHeader
|
||||||
|
}
|
||||||
|
if device.aSecCfg.underloadPacketMagicHeader != 0 &&
|
||||||
|
device.aSecCfg.underloadPacketMagicHeader != 1 {
|
||||||
|
|
||||||
|
MessageCookieReplyType = device.aSecCfg.underloadPacketMagicHeader
|
||||||
|
}
|
||||||
|
if device.aSecCfg.transportPacketMagicHeader != 0 &&
|
||||||
|
device.aSecCfg.transportPacketMagicHeader != 1 {
|
||||||
|
|
||||||
|
MessageTransportType = device.aSecCfg.transportPacketMagicHeader
|
||||||
|
}
|
||||||
|
|
||||||
|
packetSizeToMsgType = map[int]uint32{
|
||||||
|
MessageInitiationSize + device.aSecCfg.initPacketJunkSize: MessageInitiationType,
|
||||||
|
MessageResponseSize + device.aSecCfg.responsePacketJunkSize: MessageResponseType,
|
||||||
|
MessageCookieReplySize: MessageCookieReplyType,
|
||||||
|
MessageTransportSize: MessageTransportType,
|
||||||
|
}
|
||||||
|
|
||||||
|
msgTypeToJunkSize = map[uint32]int{
|
||||||
|
MessageInitiationType: device.aSecCfg.initPacketJunkSize,
|
||||||
|
MessageResponseType: device.aSecCfg.responsePacketJunkSize,
|
||||||
|
MessageCookieReplyType: 0,
|
||||||
|
MessageTransportType: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -91,6 +91,65 @@ func genConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func genASecurityConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) {
|
||||||
|
var key1, key2 NoisePrivateKey
|
||||||
|
_, err := rand.Read(key1[:])
|
||||||
|
if err != nil {
|
||||||
|
tb.Errorf("unable to generate private key random bytes: %v", err)
|
||||||
|
}
|
||||||
|
_, err = rand.Read(key2[:])
|
||||||
|
if err != nil {
|
||||||
|
tb.Errorf("unable to generate private key random bytes: %v", err)
|
||||||
|
}
|
||||||
|
pub1, pub2 := key1.publicKey(), key2.publicKey()
|
||||||
|
|
||||||
|
cfgs[0] = uapiCfg(
|
||||||
|
"private_key", hex.EncodeToString(key1[:]),
|
||||||
|
"listen_port", "0",
|
||||||
|
"replace_peers", "true",
|
||||||
|
"jc", "5",
|
||||||
|
"jmin", "500",
|
||||||
|
"jmax", "501",
|
||||||
|
"s1", "30",
|
||||||
|
"s2", "40",
|
||||||
|
"h1", "123456",
|
||||||
|
"h2", "67543",
|
||||||
|
"h4", "32345",
|
||||||
|
"h3", "123123",
|
||||||
|
"public_key", hex.EncodeToString(pub2[:]),
|
||||||
|
"protocol_version", "1",
|
||||||
|
"replace_allowed_ips", "true",
|
||||||
|
"allowed_ip", "1.0.0.2/32",
|
||||||
|
)
|
||||||
|
endpointCfgs[0] = uapiCfg(
|
||||||
|
"public_key", hex.EncodeToString(pub2[:]),
|
||||||
|
"endpoint", "127.0.0.1:%d",
|
||||||
|
)
|
||||||
|
cfgs[1] = uapiCfg(
|
||||||
|
"private_key", hex.EncodeToString(key2[:]),
|
||||||
|
"listen_port", "0",
|
||||||
|
"replace_peers", "true",
|
||||||
|
"jc", "5",
|
||||||
|
"jmin", "500",
|
||||||
|
"jmax", "501",
|
||||||
|
"s1", "30",
|
||||||
|
"s2", "40",
|
||||||
|
"h1", "123456",
|
||||||
|
"h2", "67543",
|
||||||
|
"h4", "32345",
|
||||||
|
"h3", "123123",
|
||||||
|
"public_key", hex.EncodeToString(pub1[:]),
|
||||||
|
"protocol_version", "1",
|
||||||
|
"replace_allowed_ips", "true",
|
||||||
|
"allowed_ip", "1.0.0.1/32",
|
||||||
|
)
|
||||||
|
endpointCfgs[1] = uapiCfg(
|
||||||
|
"public_key", hex.EncodeToString(pub1[:]),
|
||||||
|
"endpoint", "127.0.0.1:%d",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// A testPair is a pair of testPeers.
|
// A testPair is a pair of testPeers.
|
||||||
type testPair [2]testPeer
|
type testPair [2]testPeer
|
||||||
|
|
||||||
|
@ -115,7 +174,11 @@ func (d SendDirection) String() string {
|
||||||
return "pong"
|
return "pong"
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pair *testPair) Send(tb testing.TB, ping SendDirection, done chan struct{}) {
|
func (pair *testPair) Send(
|
||||||
|
tb testing.TB,
|
||||||
|
ping SendDirection,
|
||||||
|
done chan struct{},
|
||||||
|
) {
|
||||||
tb.Helper()
|
tb.Helper()
|
||||||
p0, p1 := pair[0], pair[1]
|
p0, p1 := pair[0], pair[1]
|
||||||
if !ping {
|
if !ping {
|
||||||
|
@ -149,8 +212,16 @@ func (pair *testPair) Send(tb testing.TB, ping SendDirection, done chan struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// genTestPair creates a testPair.
|
// genTestPair creates a testPair.
|
||||||
func genTestPair(tb testing.TB, realSocket bool) (pair testPair) {
|
func genTestPair(
|
||||||
cfg, endpointCfg := genConfigs(tb)
|
tb testing.TB,
|
||||||
|
realSocket, withASecurity bool,
|
||||||
|
) (pair testPair) {
|
||||||
|
var cfg, endpointCfg [2]string
|
||||||
|
if withASecurity {
|
||||||
|
cfg, endpointCfg = genASecurityConfigs(tb)
|
||||||
|
} else {
|
||||||
|
cfg, endpointCfg = genConfigs(tb)
|
||||||
|
}
|
||||||
var binds [2]conn.Bind
|
var binds [2]conn.Bind
|
||||||
if realSocket {
|
if realSocket {
|
||||||
binds[0], binds[1] = conn.NewDefaultBind(), conn.NewDefaultBind()
|
binds[0], binds[1] = conn.NewDefaultBind(), conn.NewDefaultBind()
|
||||||
|
@ -166,7 +237,11 @@ func genTestPair(tb testing.TB, realSocket bool) (pair testPair) {
|
||||||
if _, ok := tb.(*testing.B); ok && !testing.Verbose() {
|
if _, ok := tb.(*testing.B); ok && !testing.Verbose() {
|
||||||
level = LogLevelError
|
level = LogLevelError
|
||||||
}
|
}
|
||||||
p.dev = NewDevice(p.tun.TUN(), binds[i], NewLogger(level, fmt.Sprintf("dev%d: ", i)))
|
p.dev = NewDevice(
|
||||||
|
p.tun.TUN(),
|
||||||
|
binds[i],
|
||||||
|
NewLogger(level, fmt.Sprintf("dev%d: ", i)),
|
||||||
|
)
|
||||||
if err := p.dev.IpcSet(cfg[i]); err != nil {
|
if err := p.dev.IpcSet(cfg[i]); err != nil {
|
||||||
tb.Errorf("failed to configure device %d: %v", i, err)
|
tb.Errorf("failed to configure device %d: %v", i, err)
|
||||||
p.dev.Close()
|
p.dev.Close()
|
||||||
|
@ -194,7 +269,18 @@ func genTestPair(tb testing.TB, realSocket bool) (pair testPair) {
|
||||||
|
|
||||||
func TestTwoDevicePing(t *testing.T) {
|
func TestTwoDevicePing(t *testing.T) {
|
||||||
goroutineLeakCheck(t)
|
goroutineLeakCheck(t)
|
||||||
pair := genTestPair(t, true)
|
pair := genTestPair(t, true, false)
|
||||||
|
t.Run("ping 1.0.0.1", func(t *testing.T) {
|
||||||
|
pair.Send(t, Ping, nil)
|
||||||
|
})
|
||||||
|
t.Run("ping 1.0.0.2", func(t *testing.T) {
|
||||||
|
pair.Send(t, Pong, nil)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTwoDevicePingASecurity(t *testing.T) {
|
||||||
|
goroutineLeakCheck(t)
|
||||||
|
pair := genTestPair(t, true, true)
|
||||||
t.Run("ping 1.0.0.1", func(t *testing.T) {
|
t.Run("ping 1.0.0.1", func(t *testing.T) {
|
||||||
pair.Send(t, Ping, nil)
|
pair.Send(t, Ping, nil)
|
||||||
})
|
})
|
||||||
|
@ -209,10 +295,15 @@ func TestUpDown(t *testing.T) {
|
||||||
const otrials = 10
|
const otrials = 10
|
||||||
|
|
||||||
for n := 0; n < otrials; n++ {
|
for n := 0; n < otrials; n++ {
|
||||||
pair := genTestPair(t, false)
|
pair := genTestPair(t, false, false)
|
||||||
for i := range pair {
|
for i := range pair {
|
||||||
for k := range pair[i].dev.peers.keyMap {
|
for k := range pair[i].dev.peers.keyMap {
|
||||||
pair[i].dev.IpcSet(fmt.Sprintf("public_key=%s\npersistent_keepalive_interval=1\n", hex.EncodeToString(k[:])))
|
pair[i].dev.IpcSet(
|
||||||
|
fmt.Sprintf(
|
||||||
|
"public_key=%s\npersistent_keepalive_interval=1\n",
|
||||||
|
hex.EncodeToString(k[:]),
|
||||||
|
),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
|
@ -224,11 +315,19 @@ func TestUpDown(t *testing.T) {
|
||||||
if err := d.Up(); err != nil {
|
if err := d.Up(); err != nil {
|
||||||
t.Errorf("failed up bring up device: %v", err)
|
t.Errorf("failed up bring up device: %v", err)
|
||||||
}
|
}
|
||||||
time.Sleep(time.Duration(rand.Intn(int(time.Nanosecond * (0x10000 - 1)))))
|
time.Sleep(
|
||||||
|
time.Duration(
|
||||||
|
rand.Intn(int(time.Nanosecond * (0x10000 - 1))),
|
||||||
|
),
|
||||||
|
)
|
||||||
if err := d.Down(); err != nil {
|
if err := d.Down(); err != nil {
|
||||||
t.Errorf("failed to bring down device: %v", err)
|
t.Errorf("failed to bring down device: %v", err)
|
||||||
}
|
}
|
||||||
time.Sleep(time.Duration(rand.Intn(int(time.Nanosecond * (0x10000 - 1)))))
|
time.Sleep(
|
||||||
|
time.Duration(
|
||||||
|
rand.Intn(int(time.Nanosecond * (0x10000 - 1))),
|
||||||
|
),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}(pair[i].dev)
|
}(pair[i].dev)
|
||||||
}
|
}
|
||||||
|
@ -243,7 +342,7 @@ func TestUpDown(t *testing.T) {
|
||||||
// TestConcurrencySafety does other things concurrently with tunnel use.
|
// TestConcurrencySafety does other things concurrently with tunnel use.
|
||||||
// It is intended to be used with the race detector to catch data races.
|
// It is intended to be used with the race detector to catch data races.
|
||||||
func TestConcurrencySafety(t *testing.T) {
|
func TestConcurrencySafety(t *testing.T) {
|
||||||
pair := genTestPair(t, true)
|
pair := genTestPair(t, true, false)
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
|
|
||||||
const warmupIters = 10
|
const warmupIters = 10
|
||||||
|
@ -294,8 +393,14 @@ func TestConcurrencySafety(t *testing.T) {
|
||||||
|
|
||||||
// Change private keys concurrently with tunnel use.
|
// Change private keys concurrently with tunnel use.
|
||||||
t.Run("privateKey", func(t *testing.T) {
|
t.Run("privateKey", func(t *testing.T) {
|
||||||
bad := uapiCfg("private_key", "7777777777777777777777777777777777777777777777777777777777777777")
|
bad := uapiCfg(
|
||||||
good := uapiCfg("private_key", hex.EncodeToString(pair[0].dev.staticIdentity.privateKey[:]))
|
"private_key",
|
||||||
|
"7777777777777777777777777777777777777777777777777777777777777777",
|
||||||
|
)
|
||||||
|
good := uapiCfg(
|
||||||
|
"private_key",
|
||||||
|
hex.EncodeToString(pair[0].dev.staticIdentity.privateKey[:]),
|
||||||
|
)
|
||||||
// Set iters to a large number like 1000 to flush out data races quickly.
|
// Set iters to a large number like 1000 to flush out data races quickly.
|
||||||
// Don't leave it large. That can cause logical races
|
// Don't leave it large. That can cause logical races
|
||||||
// in which the handshake is interleaved with key changes
|
// in which the handshake is interleaved with key changes
|
||||||
|
@ -324,7 +429,7 @@ func TestConcurrencySafety(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkLatency(b *testing.B) {
|
func BenchmarkLatency(b *testing.B) {
|
||||||
pair := genTestPair(b, true)
|
pair := genTestPair(b, true, false)
|
||||||
|
|
||||||
// Establish a connection.
|
// Establish a connection.
|
||||||
pair.Send(b, Ping, nil)
|
pair.Send(b, Ping, nil)
|
||||||
|
@ -338,7 +443,7 @@ func BenchmarkLatency(b *testing.B) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkThroughput(b *testing.B) {
|
func BenchmarkThroughput(b *testing.B) {
|
||||||
pair := genTestPair(b, true)
|
pair := genTestPair(b, true, false)
|
||||||
|
|
||||||
// Establish a connection.
|
// Establish a connection.
|
||||||
pair.Send(b, Ping, nil)
|
pair.Send(b, Ping, nil)
|
||||||
|
@ -382,7 +487,7 @@ func BenchmarkThroughput(b *testing.B) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkUAPIGet(b *testing.B) {
|
func BenchmarkUAPIGet(b *testing.B) {
|
||||||
pair := genTestPair(b, true)
|
pair := genTestPair(b, true, false)
|
||||||
pair.Send(b, Ping, nil)
|
pair.Send(b, Ping, nil)
|
||||||
pair.Send(b, Pong, nil)
|
pair.Send(b, Pong, nil)
|
||||||
b.ReportAllocs()
|
b.ReportAllocs()
|
||||||
|
@ -415,7 +520,11 @@ func goroutineLeakCheck(t *testing.T) {
|
||||||
endGoroutines, endStacks := goroutines()
|
endGoroutines, endStacks := goroutines()
|
||||||
t.Logf("starting stacks:\n%s\n", startStacks)
|
t.Logf("starting stacks:\n%s\n", startStacks)
|
||||||
t.Logf("ending stacks:\n%s\n", endStacks)
|
t.Logf("ending stacks:\n%s\n", endStacks)
|
||||||
t.Fatalf("expected %d goroutines, got %d, leak?", startGoroutines, endGoroutines)
|
t.Fatalf(
|
||||||
|
"expected %d goroutines, got %d, leak?",
|
||||||
|
startGoroutines,
|
||||||
|
endGoroutines,
|
||||||
|
)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -423,29 +532,65 @@ type fakeBindSized struct {
|
||||||
size int
|
size int
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *fakeBindSized) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
|
func (b *fakeBindSized) Open(
|
||||||
|
port uint16,
|
||||||
|
) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
|
||||||
return nil, 0, nil
|
return nil, 0, nil
|
||||||
}
|
}
|
||||||
func (b *fakeBindSized) Close() error { return nil }
|
|
||||||
func (b *fakeBindSized) SetMark(mark uint32) error { return nil }
|
func (b *fakeBindSized) Close() error { return nil }
|
||||||
func (b *fakeBindSized) Send(bufs [][]byte, ep conn.Endpoint) error { return nil }
|
|
||||||
func (b *fakeBindSized) ParseEndpoint(s string) (conn.Endpoint, error) { return nil, nil }
|
func (b *fakeBindSized) SetMark(
|
||||||
func (b *fakeBindSized) BatchSize() int { return b.size }
|
mark uint32,
|
||||||
|
) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *fakeBindSized) Send(
|
||||||
|
bufs [][]byte,
|
||||||
|
ep conn.Endpoint,
|
||||||
|
) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *fakeBindSized) ParseEndpoint(
|
||||||
|
s string,
|
||||||
|
) (conn.Endpoint, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *fakeBindSized) BatchSize() int { return b.size }
|
||||||
|
|
||||||
type fakeTUNDeviceSized struct {
|
type fakeTUNDeviceSized struct {
|
||||||
size int
|
size int
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *fakeTUNDeviceSized) File() *os.File { return nil }
|
func (t *fakeTUNDeviceSized) File() *os.File { return nil }
|
||||||
func (t *fakeTUNDeviceSized) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
|
|
||||||
|
func (t *fakeTUNDeviceSized) Read(
|
||||||
|
bufs [][]byte,
|
||||||
|
sizes []int,
|
||||||
|
offset int,
|
||||||
|
) (n int, err error) {
|
||||||
return 0, nil
|
return 0, nil
|
||||||
}
|
}
|
||||||
func (t *fakeTUNDeviceSized) Write(bufs [][]byte, offset int) (int, error) { return 0, nil }
|
|
||||||
func (t *fakeTUNDeviceSized) MTU() (int, error) { return 0, nil }
|
func (t *fakeTUNDeviceSized) Write(
|
||||||
func (t *fakeTUNDeviceSized) Name() (string, error) { return "", nil }
|
bufs [][]byte,
|
||||||
func (t *fakeTUNDeviceSized) Events() <-chan tun.Event { return nil }
|
offset int,
|
||||||
func (t *fakeTUNDeviceSized) Close() error { return nil }
|
) (int, error) {
|
||||||
func (t *fakeTUNDeviceSized) BatchSize() int { return t.size }
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *fakeTUNDeviceSized) MTU() (int, error) { return 0, nil }
|
||||||
|
|
||||||
|
func (t *fakeTUNDeviceSized) Name() (string, error) { return "", nil }
|
||||||
|
|
||||||
|
func (t *fakeTUNDeviceSized) Events() <-chan tun.Event { return nil }
|
||||||
|
|
||||||
|
func (t *fakeTUNDeviceSized) Close() error { return nil }
|
||||||
|
|
||||||
|
func (t *fakeTUNDeviceSized) BatchSize() int { return t.size }
|
||||||
|
|
||||||
func TestBatchSize(t *testing.T) {
|
func TestBatchSize(t *testing.T) {
|
||||||
d := Device{}
|
d := Device{}
|
||||||
|
|
|
@ -15,7 +15,6 @@ import (
|
||||||
"golang.org/x/crypto/chacha20poly1305"
|
"golang.org/x/crypto/chacha20poly1305"
|
||||||
"golang.org/x/crypto/poly1305"
|
"golang.org/x/crypto/poly1305"
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/cfg"
|
|
||||||
"golang.zx2c4.com/wireguard/tai64n"
|
"golang.zx2c4.com/wireguard/tai64n"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -53,11 +52,11 @@ const (
|
||||||
WGLabelCookie = "cookie--"
|
WGLabelCookie = "cookie--"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
var (
|
||||||
MessageInitiationType = cfg.InitPacketMagicHeader
|
MessageInitiationType uint32 = 1
|
||||||
MessageResponseType = cfg.ResponsePacketMagicHeader
|
MessageResponseType uint32 = 2
|
||||||
MessageCookieReplyType = cfg.UnderloadPacketMagicHeader
|
MessageCookieReplyType uint32 = 3
|
||||||
MessageTransportType = cfg.TransportPacketMagicHeader
|
MessageTransportType uint32 = 4
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -76,19 +75,9 @@ const (
|
||||||
MessageTransportOffsetContent = 16
|
MessageTransportOffsetContent = 16
|
||||||
)
|
)
|
||||||
|
|
||||||
var packetSizeToMsgType = map[int]uint32{
|
var packetSizeToMsgType map[int]uint32
|
||||||
MessageInitiationSize + cfg.InitPacketJunkSize: MessageInitiationType,
|
|
||||||
MessageResponseSize + cfg.ResponsePacketJunkSize: MessageResponseType,
|
|
||||||
MessageCookieReplySize + cfg.UnderLoadPacketJunkSize: MessageCookieReplyType,
|
|
||||||
MessageTransportSize + cfg.TransportPacketJunkSize: MessageTransportType,
|
|
||||||
}
|
|
||||||
|
|
||||||
var msgTypeToJunkSize = map[uint32]int{
|
var msgTypeToJunkSize map[uint32]int
|
||||||
MessageInitiationType: cfg.InitPacketJunkSize,
|
|
||||||
MessageResponseType: cfg.ResponsePacketJunkSize,
|
|
||||||
MessageCookieReplyType: cfg.UnderLoadPacketJunkSize,
|
|
||||||
MessageTransportType: cfg.TransportPacketJunkSize,
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Type is an 8-bit field, followed by 3 nul bytes,
|
/* Type is an 8-bit field, followed by 3 nul bytes,
|
||||||
* by marshalling the messages in little-endian byteorder
|
* by marshalling the messages in little-endian byteorder
|
||||||
|
|
|
@ -16,7 +16,6 @@ import (
|
||||||
"golang.org/x/crypto/chacha20poly1305"
|
"golang.org/x/crypto/chacha20poly1305"
|
||||||
"golang.org/x/net/ipv4"
|
"golang.org/x/net/ipv4"
|
||||||
"golang.org/x/net/ipv6"
|
"golang.org/x/net/ipv6"
|
||||||
"golang.zx2c4.com/wireguard/cfg"
|
|
||||||
"golang.zx2c4.com/wireguard/conn"
|
"golang.zx2c4.com/wireguard/conn"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -142,12 +141,10 @@ func (device *Device) RoutineReceiveIncoming(
|
||||||
// check size of packet
|
// check size of packet
|
||||||
|
|
||||||
packet := bufsArrs[i][:size]
|
packet := bufsArrs[i][:size]
|
||||||
if cfg.IsAdvancedSecurityOn() {
|
if device.isAdvancedSecurityOn() {
|
||||||
var junkSize int
|
var junkSize int
|
||||||
if mapMsgType, ok := packetSizeToMsgType[size]; ok {
|
if msgType, ok := packetSizeToMsgType[size]; ok {
|
||||||
junkSize = msgTypeToJunkSize[mapMsgType]
|
junkSize = msgTypeToJunkSize[msgType]
|
||||||
} else {
|
|
||||||
junkSize = cfg.TransportPacketJunkSize
|
|
||||||
}
|
}
|
||||||
// shift junk
|
// shift junk
|
||||||
packet = packet[junkSize:]
|
packet = packet[junkSize:]
|
||||||
|
|
|
@ -19,7 +19,6 @@ import (
|
||||||
"golang.org/x/crypto/chacha20poly1305"
|
"golang.org/x/crypto/chacha20poly1305"
|
||||||
"golang.org/x/net/ipv4"
|
"golang.org/x/net/ipv4"
|
||||||
"golang.org/x/net/ipv6"
|
"golang.org/x/net/ipv6"
|
||||||
"golang.zx2c4.com/wireguard/cfg"
|
|
||||||
"golang.zx2c4.com/wireguard/tun"
|
"golang.zx2c4.com/wireguard/tun"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -128,15 +127,15 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
|
||||||
}
|
}
|
||||||
// so only packet processed for cookie generation
|
// so only packet processed for cookie generation
|
||||||
var junkedHeader []byte
|
var junkedHeader []byte
|
||||||
if cfg.IsAdvancedSecurityOn() {
|
if peer.device.isAdvancedSecurityOn() {
|
||||||
err = peer.sendJunkPackets()
|
err = peer.sendJunkPackets()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
peer.device.log.Errorf("%v - %v", peer, err)
|
peer.device.log.Errorf("%v - %v", peer, err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
var buf [cfg.InitPacketJunkSize]byte
|
buf := make([]byte, 0, peer.device.aSecCfg.initPacketJunkSize)
|
||||||
writer := bytes.NewBuffer(buf[:0])
|
writer := bytes.NewBuffer(buf[:0])
|
||||||
err = appendJunk(writer, cfg.InitPacketJunkSize)
|
err = appendJunk(writer, peer.device.aSecCfg.initPacketJunkSize)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
peer.device.log.Errorf("%v - %v", peer, err)
|
peer.device.log.Errorf("%v - %v", peer, err)
|
||||||
return err
|
return err
|
||||||
|
@ -183,10 +182,10 @@ func (peer *Peer) SendHandshakeResponse() error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
var junkedHeader []byte
|
var junkedHeader []byte
|
||||||
if cfg.IsAdvancedSecurityOn() {
|
if peer.device.isAdvancedSecurityOn() {
|
||||||
var buf [cfg.ResponsePacketJunkSize]byte
|
buf := make([]byte, 0, peer.device.aSecCfg.responsePacketJunkSize)
|
||||||
writer := bytes.NewBuffer(buf[:0])
|
writer := bytes.NewBuffer(buf[:0])
|
||||||
err = appendJunk(writer, cfg.ResponsePacketJunkSize)
|
err = appendJunk(writer, peer.device.aSecCfg.responsePacketJunkSize)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
peer.device.log.Errorf("%v - %v", peer, err)
|
peer.device.log.Errorf("%v - %v", peer, err)
|
||||||
return err
|
return err
|
||||||
|
@ -472,11 +471,11 @@ top:
|
||||||
}
|
}
|
||||||
|
|
||||||
func (peer *Peer) sendJunkPackets() error {
|
func (peer *Peer) sendJunkPackets() error {
|
||||||
junks := make([][]byte, 0, cfg.JunkPacketCount)
|
junks := make([][]byte, 0, peer.device.aSecCfg.junkPacketCount)
|
||||||
for i := 0; i < cfg.JunkPacketCount; i++ {
|
for i := 0; i < peer.device.aSecCfg.junkPacketCount; i++ {
|
||||||
packetSize := rand.Intn(
|
packetSize := rand.Intn(
|
||||||
cfg.JunkPacketMaxSize-cfg.JunkPacketMinSize,
|
peer.device.aSecCfg.junkPacketMaxSize-peer.device.aSecCfg.junkPacketMinSize,
|
||||||
) + cfg.JunkPacketMinSize
|
) + peer.device.aSecCfg.junkPacketMinSize
|
||||||
|
|
||||||
junk, err := randomJunkWithSize(packetSize)
|
junk, err := randomJunkWithSize(packetSize)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
316
device/uapi.go
316
device/uapi.go
|
@ -11,6 +11,7 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"log"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
@ -97,6 +98,36 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
|
||||||
sendf("fwmark=%d", device.net.fwmark)
|
sendf("fwmark=%d", device.net.fwmark)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if device.isAdvancedSecurityOn() {
|
||||||
|
if device.aSecCfg.junkPacketCount != 0 {
|
||||||
|
sendf("jc=%d", device.aSecCfg.junkPacketCount)
|
||||||
|
}
|
||||||
|
if device.aSecCfg.junkPacketMinSize != 0 {
|
||||||
|
sendf("jmin=%d", device.aSecCfg.junkPacketMinSize)
|
||||||
|
}
|
||||||
|
if device.aSecCfg.junkPacketMaxSize != 0 {
|
||||||
|
sendf("jmax=%d", device.aSecCfg.junkPacketMaxSize)
|
||||||
|
}
|
||||||
|
if device.aSecCfg.initPacketJunkSize != 0 {
|
||||||
|
sendf("s1=%d", device.aSecCfg.initPacketJunkSize)
|
||||||
|
}
|
||||||
|
if device.aSecCfg.responsePacketJunkSize != 0 {
|
||||||
|
sendf("s2=%d", device.aSecCfg.responsePacketJunkSize)
|
||||||
|
}
|
||||||
|
if device.aSecCfg.initPacketMagicHeader != 0 {
|
||||||
|
sendf("h1=%d", device.aSecCfg.initPacketMagicHeader)
|
||||||
|
}
|
||||||
|
if device.aSecCfg.responsePacketMagicHeader != 0 {
|
||||||
|
sendf("h2=%d", device.aSecCfg.responsePacketMagicHeader)
|
||||||
|
}
|
||||||
|
if device.aSecCfg.underloadPacketMagicHeader != 0 {
|
||||||
|
sendf("h3=%d", device.aSecCfg.underloadPacketMagicHeader)
|
||||||
|
}
|
||||||
|
if device.aSecCfg.transportPacketMagicHeader != 0 {
|
||||||
|
sendf("h4=%d", device.aSecCfg.transportPacketMagicHeader)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for _, peer := range device.peers.keyMap {
|
for _, peer := range device.peers.keyMap {
|
||||||
// Serialize peer state.
|
// Serialize peer state.
|
||||||
// Do the work in an anonymous function so that we can use defer.
|
// Do the work in an anonymous function so that we can use defer.
|
||||||
|
@ -119,12 +150,18 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
|
||||||
sendf("last_handshake_time_nsec=%d", nano)
|
sendf("last_handshake_time_nsec=%d", nano)
|
||||||
sendf("tx_bytes=%d", peer.txBytes.Load())
|
sendf("tx_bytes=%d", peer.txBytes.Load())
|
||||||
sendf("rx_bytes=%d", peer.rxBytes.Load())
|
sendf("rx_bytes=%d", peer.rxBytes.Load())
|
||||||
sendf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval.Load())
|
sendf(
|
||||||
|
"persistent_keepalive_interval=%d",
|
||||||
|
peer.persistentKeepaliveInterval.Load(),
|
||||||
|
)
|
||||||
|
|
||||||
device.allowedips.EntriesForPeer(peer, func(prefix netip.Prefix) bool {
|
device.allowedips.EntriesForPeer(
|
||||||
sendf("allowed_ip=%s", prefix.String())
|
peer,
|
||||||
return true
|
func(prefix netip.Prefix) bool {
|
||||||
})
|
sendf("allowed_ip=%s", prefix.String())
|
||||||
|
return true
|
||||||
|
},
|
||||||
|
)
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
@ -162,7 +199,11 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
|
||||||
}
|
}
|
||||||
key, value, ok := strings.Cut(line, "=")
|
key, value, ok := strings.Cut(line, "=")
|
||||||
if !ok {
|
if !ok {
|
||||||
return ipcErrorf(ipc.IpcErrorProtocol, "failed to parse line %q", line)
|
return ipcErrorf(
|
||||||
|
ipc.IpcErrorProtocol,
|
||||||
|
"failed to parse line %q",
|
||||||
|
line,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
if key == "public_key" {
|
if key == "public_key" {
|
||||||
|
@ -188,6 +229,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
device.handlePostConfig()
|
||||||
peer.handlePostConfig()
|
peer.handlePostConfig()
|
||||||
|
|
||||||
if err := scanner.Err(); err != nil {
|
if err := scanner.Err(); err != nil {
|
||||||
|
@ -202,7 +244,11 @@ func (device *Device) handleDeviceLine(key, value string) error {
|
||||||
var sk NoisePrivateKey
|
var sk NoisePrivateKey
|
||||||
err := sk.FromMaybeZeroHex(value)
|
err := sk.FromMaybeZeroHex(value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set private_key: %w", err)
|
return ipcErrorf(
|
||||||
|
ipc.IpcErrorInvalid,
|
||||||
|
"failed to set private_key: %w",
|
||||||
|
err,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
device.log.Verbosef("UAPI: Updating private key")
|
device.log.Verbosef("UAPI: Updating private key")
|
||||||
device.SetPrivateKey(sk)
|
device.SetPrivateKey(sk)
|
||||||
|
@ -210,7 +256,11 @@ func (device *Device) handleDeviceLine(key, value string) error {
|
||||||
case "listen_port":
|
case "listen_port":
|
||||||
port, err := strconv.ParseUint(value, 10, 16)
|
port, err := strconv.ParseUint(value, 10, 16)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ipcErrorf(ipc.IpcErrorInvalid, "failed to parse listen_port: %w", err)
|
return ipcErrorf(
|
||||||
|
ipc.IpcErrorInvalid,
|
||||||
|
"failed to parse listen_port: %w",
|
||||||
|
err,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// update port and rebind
|
// update port and rebind
|
||||||
|
@ -221,7 +271,11 @@ func (device *Device) handleDeviceLine(key, value string) error {
|
||||||
device.net.Unlock()
|
device.net.Unlock()
|
||||||
|
|
||||||
if err := device.BindUpdate(); err != nil {
|
if err := device.BindUpdate(); err != nil {
|
||||||
return ipcErrorf(ipc.IpcErrorPortInUse, "failed to set listen_port: %w", err)
|
return ipcErrorf(
|
||||||
|
ipc.IpcErrorPortInUse,
|
||||||
|
"failed to set listen_port: %w",
|
||||||
|
err,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
case "fwmark":
|
case "fwmark":
|
||||||
|
@ -232,18 +286,171 @@ func (device *Device) handleDeviceLine(key, value string) error {
|
||||||
|
|
||||||
device.log.Verbosef("UAPI: Updating fwmark")
|
device.log.Verbosef("UAPI: Updating fwmark")
|
||||||
if err := device.BindSetMark(uint32(mark)); err != nil {
|
if err := device.BindSetMark(uint32(mark)); err != nil {
|
||||||
return ipcErrorf(ipc.IpcErrorPortInUse, "failed to update fwmark: %w", err)
|
return ipcErrorf(
|
||||||
|
ipc.IpcErrorPortInUse,
|
||||||
|
"failed to update fwmark: %w",
|
||||||
|
err,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
case "replace_peers":
|
case "replace_peers":
|
||||||
if value != "true" {
|
if value != "true" {
|
||||||
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set replace_peers, invalid value: %v", value)
|
return ipcErrorf(
|
||||||
|
ipc.IpcErrorInvalid,
|
||||||
|
"failed to set replace_peers, invalid value: %v",
|
||||||
|
value,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
device.log.Verbosef("UAPI: Removing all peers")
|
device.log.Verbosef("UAPI: Removing all peers")
|
||||||
device.RemoveAllPeers()
|
device.RemoveAllPeers()
|
||||||
|
case "jc":
|
||||||
|
junkPacketCount, err := strconv.Atoi(value)
|
||||||
|
if err != nil {
|
||||||
|
return ipcErrorf(
|
||||||
|
ipc.IpcErrorInvalid,
|
||||||
|
"faield to parse junk_packet_count %w",
|
||||||
|
err,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
if junkPacketCount < 0 {
|
||||||
|
log.Fatalf("JunkPacketCount should be non negative")
|
||||||
|
}
|
||||||
|
device.log.Verbosef("UAPI: Updating junk_packet_count")
|
||||||
|
device.aSecCfg.isOn = true
|
||||||
|
device.aSecCfg.junkPacketCount = junkPacketCount
|
||||||
|
case "jmin":
|
||||||
|
junkPacketMinSize, err := strconv.Atoi(value)
|
||||||
|
if err != nil {
|
||||||
|
return ipcErrorf(
|
||||||
|
ipc.IpcErrorInvalid,
|
||||||
|
"faield to parse junk_packet_min_size %w",
|
||||||
|
err,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
device.log.Verbosef("UAPI: Updating junk_packet_min_size")
|
||||||
|
device.aSecCfg.isOn = true
|
||||||
|
device.aSecCfg.junkPacketMinSize = junkPacketMinSize
|
||||||
|
case "jmax":
|
||||||
|
junkPacketMaxSize, err := strconv.Atoi(value)
|
||||||
|
if err != nil {
|
||||||
|
return ipcErrorf(
|
||||||
|
ipc.IpcErrorInvalid,
|
||||||
|
"faield to parse junk_packet_max_size %w",
|
||||||
|
err,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
if junkPacketMaxSize >= MaxSegmentSize {
|
||||||
|
log.Fatalf(
|
||||||
|
"JunkPacketMaxSize: %d; should be smaller than maxSegmentSize: %d",
|
||||||
|
junkPacketMaxSize,
|
||||||
|
MaxSegmentSize,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
device.log.Verbosef("UAPI: Updating junk_packet_max_size")
|
||||||
|
device.aSecCfg.isOn = true
|
||||||
|
device.aSecCfg.junkPacketMaxSize = junkPacketMaxSize
|
||||||
|
case "s1":
|
||||||
|
initPacketJunkSize, err := strconv.Atoi(value)
|
||||||
|
if err != nil {
|
||||||
|
return ipcErrorf(
|
||||||
|
ipc.IpcErrorInvalid,
|
||||||
|
"faield to parse init_packet_junk_size %w",
|
||||||
|
err,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
if 148+initPacketJunkSize >= MaxSegmentSize {
|
||||||
|
log.Fatalf(
|
||||||
|
`init header size(148) + junkSize:%d;
|
||||||
|
should be smaller than maxSegmentSize: %d`,
|
||||||
|
initPacketJunkSize,
|
||||||
|
MaxSegmentSize,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
device.log.Verbosef("UAPI: Updating init_packet_junk_size")
|
||||||
|
device.aSecCfg.isOn = true
|
||||||
|
device.aSecCfg.initPacketJunkSize = initPacketJunkSize
|
||||||
|
case "s2":
|
||||||
|
responsePacketJunkSize, err := strconv.Atoi(value)
|
||||||
|
if err != nil {
|
||||||
|
return ipcErrorf(
|
||||||
|
ipc.IpcErrorInvalid,
|
||||||
|
"faield to parse response_packet_junk_size %w",
|
||||||
|
err,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
if 92+responsePacketJunkSize >= MaxSegmentSize {
|
||||||
|
log.Fatalf(
|
||||||
|
`response header size(92) + junkSize:%d;
|
||||||
|
should be smaller than maxSegmentSize: %d`,
|
||||||
|
responsePacketJunkSize,
|
||||||
|
MaxSegmentSize,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
device.log.Verbosef("UAPI: Updating response_packet_junk_size")
|
||||||
|
device.aSecCfg.isOn = true
|
||||||
|
device.aSecCfg.responsePacketJunkSize = responsePacketJunkSize
|
||||||
|
|
||||||
|
case "h1":
|
||||||
|
initPacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
|
||||||
|
if err != nil {
|
||||||
|
return ipcErrorf(
|
||||||
|
ipc.IpcErrorInvalid,
|
||||||
|
"faield to parse init_packet_magic_header %w",
|
||||||
|
err,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
device.log.Verbosef("UAPI: Updating init_packet_magic_header")
|
||||||
|
device.aSecCfg.isOn = true
|
||||||
|
device.aSecCfg.initPacketMagicHeader = uint32(initPacketMagicHeader)
|
||||||
|
case "h2":
|
||||||
|
responsePacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
|
||||||
|
if err != nil {
|
||||||
|
return ipcErrorf(
|
||||||
|
ipc.IpcErrorInvalid,
|
||||||
|
"faield to parse response_packet_magic_header %w",
|
||||||
|
err,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
device.log.Verbosef("UAPI: Updating response_packet_magic_header")
|
||||||
|
device.aSecCfg.isOn = true
|
||||||
|
device.aSecCfg.responsePacketMagicHeader = uint32(
|
||||||
|
responsePacketMagicHeader,
|
||||||
|
)
|
||||||
|
case "h3":
|
||||||
|
underloadPacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
|
||||||
|
if err != nil {
|
||||||
|
return ipcErrorf(
|
||||||
|
ipc.IpcErrorInvalid,
|
||||||
|
"faield to parse underload_packet_magic_header %w",
|
||||||
|
err,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
device.log.Verbosef("UAPI: Updating underload_packet_magic_header")
|
||||||
|
device.aSecCfg.isOn = true
|
||||||
|
device.aSecCfg.underloadPacketMagicHeader = uint32(
|
||||||
|
underloadPacketMagicHeader,
|
||||||
|
)
|
||||||
|
case "h4":
|
||||||
|
transportPacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
|
||||||
|
if err != nil {
|
||||||
|
return ipcErrorf(
|
||||||
|
ipc.IpcErrorInvalid,
|
||||||
|
"faield to parse transport_packet_magic_header %w",
|
||||||
|
err,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
device.log.Verbosef("UAPI: Updating transport_packet_magic_header")
|
||||||
|
device.aSecCfg.isOn = true
|
||||||
|
device.aSecCfg.transportPacketMagicHeader = uint32(
|
||||||
|
transportPacketMagicHeader,
|
||||||
|
)
|
||||||
default:
|
default:
|
||||||
return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI device key: %v", key)
|
return ipcErrorf(
|
||||||
|
ipc.IpcErrorInvalid,
|
||||||
|
"invalid UAPI device key: %v",
|
||||||
|
key,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
@ -262,7 +469,8 @@ func (peer *ipcSetPeer) handlePostConfig() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if peer.created {
|
if peer.created {
|
||||||
peer.disableRoaming = peer.device.net.brokenRoaming && peer.endpoint != nil
|
peer.disableRoaming = peer.device.net.brokenRoaming &&
|
||||||
|
peer.endpoint != nil
|
||||||
}
|
}
|
||||||
if peer.device.isUp() {
|
if peer.device.isUp() {
|
||||||
peer.Start()
|
peer.Start()
|
||||||
|
@ -273,12 +481,19 @@ func (peer *ipcSetPeer) handlePostConfig() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (device *Device) handlePublicKeyLine(peer *ipcSetPeer, value string) error {
|
func (device *Device) handlePublicKeyLine(
|
||||||
|
peer *ipcSetPeer,
|
||||||
|
value string,
|
||||||
|
) error {
|
||||||
// Load/create the peer we are configuring.
|
// Load/create the peer we are configuring.
|
||||||
var publicKey NoisePublicKey
|
var publicKey NoisePublicKey
|
||||||
err := publicKey.FromHex(value)
|
err := publicKey.FromHex(value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ipcErrorf(ipc.IpcErrorInvalid, "failed to get peer by public key: %w", err)
|
return ipcErrorf(
|
||||||
|
ipc.IpcErrorInvalid,
|
||||||
|
"failed to get peer by public key: %w",
|
||||||
|
err,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ignore peer with the same public key as this device.
|
// Ignore peer with the same public key as this device.
|
||||||
|
@ -296,19 +511,30 @@ func (device *Device) handlePublicKeyLine(peer *ipcSetPeer, value string) error
|
||||||
if peer.created {
|
if peer.created {
|
||||||
peer.Peer, err = device.NewPeer(publicKey)
|
peer.Peer, err = device.NewPeer(publicKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ipcErrorf(ipc.IpcErrorInvalid, "failed to create new peer: %w", err)
|
return ipcErrorf(
|
||||||
|
ipc.IpcErrorInvalid,
|
||||||
|
"failed to create new peer: %w",
|
||||||
|
err,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
device.log.Verbosef("%v - UAPI: Created", peer.Peer)
|
device.log.Verbosef("%v - UAPI: Created", peer.Peer)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error {
|
func (device *Device) handlePeerLine(
|
||||||
|
peer *ipcSetPeer,
|
||||||
|
key, value string,
|
||||||
|
) error {
|
||||||
switch key {
|
switch key {
|
||||||
case "update_only":
|
case "update_only":
|
||||||
// allow disabling of creation
|
// allow disabling of creation
|
||||||
if value != "true" {
|
if value != "true" {
|
||||||
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set update only, invalid value: %v", value)
|
return ipcErrorf(
|
||||||
|
ipc.IpcErrorInvalid,
|
||||||
|
"failed to set update only, invalid value: %v",
|
||||||
|
value,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
if peer.created && !peer.dummy {
|
if peer.created && !peer.dummy {
|
||||||
device.RemovePeer(peer.handshake.remoteStatic)
|
device.RemovePeer(peer.handshake.remoteStatic)
|
||||||
|
@ -319,7 +545,11 @@ func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error
|
||||||
case "remove":
|
case "remove":
|
||||||
// remove currently selected peer from device
|
// remove currently selected peer from device
|
||||||
if value != "true" {
|
if value != "true" {
|
||||||
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set remove, invalid value: %v", value)
|
return ipcErrorf(
|
||||||
|
ipc.IpcErrorInvalid,
|
||||||
|
"failed to set remove, invalid value: %v",
|
||||||
|
value,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
if !peer.dummy {
|
if !peer.dummy {
|
||||||
device.log.Verbosef("%v - UAPI: Removing", peer.Peer)
|
device.log.Verbosef("%v - UAPI: Removing", peer.Peer)
|
||||||
|
@ -336,25 +566,41 @@ func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error
|
||||||
peer.handshake.mutex.Unlock()
|
peer.handshake.mutex.Unlock()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set preshared key: %w", err)
|
return ipcErrorf(
|
||||||
|
ipc.IpcErrorInvalid,
|
||||||
|
"failed to set preshared key: %w",
|
||||||
|
err,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
case "endpoint":
|
case "endpoint":
|
||||||
device.log.Verbosef("%v - UAPI: Updating endpoint", peer.Peer)
|
device.log.Verbosef("%v - UAPI: Updating endpoint", peer.Peer)
|
||||||
endpoint, err := device.net.bind.ParseEndpoint(value)
|
endpoint, err := device.net.bind.ParseEndpoint(value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set endpoint %v: %w", value, err)
|
return ipcErrorf(
|
||||||
|
ipc.IpcErrorInvalid,
|
||||||
|
"failed to set endpoint %v: %w",
|
||||||
|
value,
|
||||||
|
err,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
peer.Lock()
|
peer.Lock()
|
||||||
defer peer.Unlock()
|
defer peer.Unlock()
|
||||||
peer.endpoint = endpoint
|
peer.endpoint = endpoint
|
||||||
|
|
||||||
case "persistent_keepalive_interval":
|
case "persistent_keepalive_interval":
|
||||||
device.log.Verbosef("%v - UAPI: Updating persistent keepalive interval", peer.Peer)
|
device.log.Verbosef(
|
||||||
|
"%v - UAPI: Updating persistent keepalive interval",
|
||||||
|
peer.Peer,
|
||||||
|
)
|
||||||
|
|
||||||
secs, err := strconv.ParseUint(value, 10, 16)
|
secs, err := strconv.ParseUint(value, 10, 16)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set persistent keepalive interval: %w", err)
|
return ipcErrorf(
|
||||||
|
ipc.IpcErrorInvalid,
|
||||||
|
"failed to set persistent keepalive interval: %w",
|
||||||
|
err,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
old := peer.persistentKeepaliveInterval.Swap(uint32(secs))
|
old := peer.persistentKeepaliveInterval.Swap(uint32(secs))
|
||||||
|
@ -365,7 +611,11 @@ func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error
|
||||||
case "replace_allowed_ips":
|
case "replace_allowed_ips":
|
||||||
device.log.Verbosef("%v - UAPI: Removing all allowedips", peer.Peer)
|
device.log.Verbosef("%v - UAPI: Removing all allowedips", peer.Peer)
|
||||||
if value != "true" {
|
if value != "true" {
|
||||||
return ipcErrorf(ipc.IpcErrorInvalid, "failed to replace allowedips, invalid value: %v", value)
|
return ipcErrorf(
|
||||||
|
ipc.IpcErrorInvalid,
|
||||||
|
"failed to replace allowedips, invalid value: %v",
|
||||||
|
value,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
if peer.dummy {
|
if peer.dummy {
|
||||||
return nil
|
return nil
|
||||||
|
@ -376,7 +626,11 @@ func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error
|
||||||
device.log.Verbosef("%v - UAPI: Adding allowedip", peer.Peer)
|
device.log.Verbosef("%v - UAPI: Adding allowedip", peer.Peer)
|
||||||
prefix, err := netip.ParsePrefix(value)
|
prefix, err := netip.ParsePrefix(value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set allowed ip: %w", err)
|
return ipcErrorf(
|
||||||
|
ipc.IpcErrorInvalid,
|
||||||
|
"failed to set allowed ip: %w",
|
||||||
|
err,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
if peer.dummy {
|
if peer.dummy {
|
||||||
return nil
|
return nil
|
||||||
|
@ -385,7 +639,11 @@ func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error
|
||||||
|
|
||||||
case "protocol_version":
|
case "protocol_version":
|
||||||
if value != "1" {
|
if value != "1" {
|
||||||
return ipcErrorf(ipc.IpcErrorInvalid, "invalid protocol version: %v", value)
|
return ipcErrorf(
|
||||||
|
ipc.IpcErrorInvalid,
|
||||||
|
"invalid protocol version: %v",
|
||||||
|
value,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
default:
|
default:
|
||||||
|
@ -433,7 +691,11 @@ func (device *Device) IpcHandle(socket net.Conn) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if nextByte != '\n' {
|
if nextByte != '\n' {
|
||||||
err = ipcErrorf(ipc.IpcErrorInvalid, "trailing character in UAPI get: %q", nextByte)
|
err = ipcErrorf(
|
||||||
|
ipc.IpcErrorInvalid,
|
||||||
|
"trailing character in UAPI get: %q",
|
||||||
|
nextByte,
|
||||||
|
)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
err = device.IpcGetOperation(buffered.Writer)
|
err = device.IpcGetOperation(buffered.Writer)
|
||||||
|
|
|
@ -4,8 +4,6 @@ import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"fmt"
|
"fmt"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/cfg"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func Test_randomJunktWithSize(t *testing.T) {
|
func Test_randomJunktWithSize(t *testing.T) {
|
||||||
|
@ -19,7 +17,7 @@ func Test_appendJunk(t *testing.T) {
|
||||||
buffer := bytes.NewBuffer([]byte(s))
|
buffer := bytes.NewBuffer([]byte(s))
|
||||||
err := appendJunk(buffer, 30)
|
err := appendJunk(buffer, 30)
|
||||||
if err != nil &&
|
if err != nil &&
|
||||||
buffer.Len() != len(s)+int(cfg.InitPacketJunkSize) {
|
buffer.Len() != len(s)+30 {
|
||||||
t.Errorf("appendWithJunk() size don't match")
|
t.Errorf("appendWithJunk() size don't match")
|
||||||
}
|
}
|
||||||
read := make([]byte, 50)
|
read := make([]byte, 50)
|
||||||
|
|
Loading…
Add table
Reference in a new issue