From 8f1a6a10b2a74ae86920a74e7f0989e9f57df990 Mon Sep 17 00:00:00 2001
From: Mark Puha
Date: Fri, 6 Oct 2023 02:11:27 +0530
Subject: [PATCH] Advanced security (#2)
* Advanced security header layer & config
---
.gitignore | 2 +-
conn/bind_windows.go | 2 +-
conn/bindtest/bindtest.go | 2 +-
device/bind_test.go | 2 +-
device/device.go | 278 ++++++++++++++++++++++++++-
device/device_test.go | 150 ++++++++++++---
device/keypair.go | 2 +-
device/noise-protocol.go | 26 ++-
device/noise_test.go | 4 +-
device/peer.go | 2 +-
device/queueconstants_android.go | 2 +-
device/queueconstants_default.go | 2 +-
device/receive.go | 59 ++++--
device/send.go | 104 ++++++++--
device/sticky_default.go | 4 +-
device/sticky_linux.go | 4 +-
device/tun.go | 2 +-
device/uapi.go | 146 ++++++++++++--
device/util.go | 25 +++
device/util_test.go | 27 +++
go.mod | 3 +-
go.sum | 2 +
ipc/namedpipe/namedpipe_test.go | 2 +-
ipc/uapi_linux.go | 2 +-
ipc/uapi_windows.go | 2 +-
main.go | 8 +-
main_windows.go | 8 +-
tun/netstack/examples/http_client.go | 6 +-
tun/netstack/examples/http_server.go | 6 +-
tun/netstack/examples/ping_client.go | 6 +-
tun/netstack/tun.go | 2 +-
tun/tcp_offload_linux.go | 2 +-
tun/tcp_offload_linux_test.go | 2 +-
tun/tun_linux.go | 4 +-
tun/tuntest/tuntest.go | 2 +-
35 files changed, 781 insertions(+), 121 deletions(-)
create mode 100644 device/util.go
create mode 100644 device/util_test.go
diff --git a/.gitignore b/.gitignore
index e460293..71549f4 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1 +1 @@
-wireguard-go
+wireguard-go
\ No newline at end of file
diff --git a/conn/bind_windows.go b/conn/bind_windows.go
index d5095e0..9bad0ee 100644
--- a/conn/bind_windows.go
+++ b/conn/bind_windows.go
@@ -17,7 +17,7 @@ import (
"golang.org/x/sys/windows"
- "golang.zx2c4.com/wireguard/conn/winrio"
+ "github.com/amnezia-vpn/amnezia-wg/conn/winrio"
)
const (
diff --git a/conn/bindtest/bindtest.go b/conn/bindtest/bindtest.go
index 74e7add..713c371 100644
--- a/conn/bindtest/bindtest.go
+++ b/conn/bindtest/bindtest.go
@@ -12,7 +12,7 @@ import (
"net/netip"
"os"
- "golang.zx2c4.com/wireguard/conn"
+ "github.com/amnezia-vpn/amnezia-wg/conn"
)
type ChannelBind struct {
diff --git a/device/bind_test.go b/device/bind_test.go
index 302a521..eae36c2 100644
--- a/device/bind_test.go
+++ b/device/bind_test.go
@@ -8,7 +8,7 @@ package device
import (
"errors"
- "golang.zx2c4.com/wireguard/conn"
+ "github.com/amnezia-vpn/amnezia-wg/conn"
)
type DummyDatagram struct {
diff --git a/device/device.go b/device/device.go
index 1af9fe0..10365d1 100644
--- a/device/device.go
+++ b/device/device.go
@@ -11,10 +11,12 @@ import (
"sync/atomic"
"time"
- "golang.zx2c4.com/wireguard/conn"
- "golang.zx2c4.com/wireguard/ratelimiter"
- "golang.zx2c4.com/wireguard/rwcancel"
- "golang.zx2c4.com/wireguard/tun"
+ "github.com/amnezia-vpn/amnezia-wg/conn"
+ "github.com/amnezia-vpn/amnezia-wg/ipc"
+ "github.com/amnezia-vpn/amnezia-wg/ratelimiter"
+ "github.com/amnezia-vpn/amnezia-wg/rwcancel"
+ "github.com/amnezia-vpn/amnezia-wg/tun"
+ "github.com/tevino/abool/v2"
)
type Device struct {
@@ -89,6 +91,22 @@ type Device struct {
ipcMutex sync.RWMutex
closed chan struct{}
log *Logger
+
+ isASecOn abool.AtomicBool
+ aSecMux sync.RWMutex
+ aSecCfg aSecCfgType
+}
+
+type aSecCfgType struct {
+ junkPacketCount int
+ junkPacketMinSize int
+ junkPacketMaxSize int
+ initPacketJunkSize int
+ responsePacketJunkSize int
+ initPacketMagicHeader uint32
+ responsePacketMagicHeader uint32
+ underloadPacketMagicHeader uint32
+ transportPacketMagicHeader uint32
}
// deviceState represents the state of a Device.
@@ -162,7 +180,8 @@ func (device *Device) changeState(want deviceState) (err error) {
err = errDown
}
}
- device.log.Verbosef("Interface state was %s, requested %s, now %s", old, want, device.deviceState())
+ device.log.Verbosef(
+ "Interface state was %s, requested %s, now %s", old, want, device.deviceState())
return
}
@@ -526,7 +545,7 @@ func (device *Device) BindUpdate() error {
// start receiving routines
device.net.stopping.Add(len(recvFns))
device.queue.decryption.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption
- device.queue.handshake.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.handshake
+ device.queue.handshake.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.handshake
batchSize := netc.bind.BatchSize()
for _, fn := range recvFns {
go device.RoutineReceiveIncoming(batchSize, fn)
@@ -542,3 +561,250 @@ func (device *Device) BindClose() error {
device.net.Unlock()
return err
}
+func (device *Device) isAdvancedSecurityOn() bool {
+ return device.isASecOn.IsSet()
+}
+
+func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
+
+ if tempASecCfg.junkPacketCount == 0 &&
+ tempASecCfg.junkPacketMaxSize == 0 &&
+ tempASecCfg.junkPacketMinSize == 0 &&
+ tempASecCfg.initPacketJunkSize == 0 &&
+ tempASecCfg.responsePacketJunkSize == 0 &&
+ tempASecCfg.initPacketMagicHeader == 0 &&
+ tempASecCfg.responsePacketMagicHeader == 0 &&
+ tempASecCfg.underloadPacketMagicHeader == 0 &&
+ tempASecCfg.transportPacketMagicHeader == 0 {
+ return err
+ }
+
+ isASecOn := false
+ device.aSecMux.Lock()
+ if tempASecCfg.junkPacketCount < 0 {
+ err = ipcErrorf(
+ ipc.IpcErrorInvalid,
+ "JunkPacketCount should be non negative",
+ )
+ }
+ device.aSecCfg.junkPacketCount = tempASecCfg.junkPacketCount
+ if tempASecCfg.junkPacketCount != 0 {
+ isASecOn = true
+ }
+
+ device.aSecCfg.junkPacketMinSize = tempASecCfg.junkPacketMinSize
+ if tempASecCfg.junkPacketMinSize != 0 {
+ isASecOn = true
+ }
+
+ if device.aSecCfg.junkPacketCount > 0 &&
+ tempASecCfg.junkPacketMaxSize == tempASecCfg.junkPacketMinSize {
+
+ tempASecCfg.junkPacketMaxSize++ // to make rand gen work
+ }
+
+ if tempASecCfg.junkPacketMaxSize >= MaxSegmentSize{
+ device.aSecCfg.junkPacketMinSize = 0
+ device.aSecCfg.junkPacketMaxSize = 1
+ if err != nil {
+ err = ipcErrorf(
+ ipc.IpcErrorInvalid,
+ "JunkPacketMaxSize: %d; should be smaller than maxSegmentSize: %d; %w",
+ tempASecCfg.junkPacketMaxSize,
+ MaxSegmentSize,
+ err,
+ )
+ } else {
+ err = ipcErrorf(
+ ipc.IpcErrorInvalid,
+ "JunkPacketMaxSize: %d; should be smaller than maxSegmentSize: %d",
+ tempASecCfg.junkPacketMaxSize,
+ MaxSegmentSize,
+ )
+ }
+ } else if tempASecCfg.junkPacketMaxSize < tempASecCfg.junkPacketMinSize {
+ if err != nil {
+ err = ipcErrorf(
+ ipc.IpcErrorInvalid,
+ "maxSize: %d; should be greater than minSize: %d; %w",
+ tempASecCfg.junkPacketMaxSize,
+ tempASecCfg.junkPacketMinSize,
+ err,
+ )
+ } else {
+ err = ipcErrorf(
+ ipc.IpcErrorInvalid,
+ "maxSize: %d; should be greater than minSize: %d",
+ tempASecCfg.junkPacketMaxSize,
+ tempASecCfg.junkPacketMinSize,
+ )
+ }
+ } else {
+ device.aSecCfg.junkPacketMaxSize = tempASecCfg.junkPacketMaxSize
+ }
+
+ if tempASecCfg.junkPacketMaxSize != 0 {
+ isASecOn = true
+ }
+
+ if MessageInitiationSize+tempASecCfg.initPacketJunkSize >= MaxSegmentSize {
+ if err != nil {
+ err = ipcErrorf(
+ ipc.IpcErrorInvalid,
+ `init header size(148) + junkSize:%d; should be smaller than maxSegmentSize: %d; %w`,
+ tempASecCfg.initPacketJunkSize,
+ MaxSegmentSize,
+ err,
+ )
+ } else {
+ err = ipcErrorf(
+ ipc.IpcErrorInvalid,
+ `init header size(148) + junkSize:%d; should be smaller than maxSegmentSize: %d`,
+ tempASecCfg.initPacketJunkSize,
+ MaxSegmentSize,
+ )
+ }
+ } else {
+ device.aSecCfg.initPacketJunkSize = tempASecCfg.initPacketJunkSize
+ }
+
+ if tempASecCfg.initPacketJunkSize != 0 {
+ isASecOn = true
+ }
+
+ if MessageResponseSize+tempASecCfg.responsePacketJunkSize >= MaxSegmentSize {
+ if err != nil {
+ err = ipcErrorf(
+ ipc.IpcErrorInvalid,
+ `response header size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d; %w`,
+ tempASecCfg.responsePacketJunkSize,
+ MaxSegmentSize,
+ err,
+ )
+ } else {
+ err = ipcErrorf(
+ ipc.IpcErrorInvalid,
+ `response header size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d`,
+ tempASecCfg.responsePacketJunkSize,
+ MaxSegmentSize,
+ )
+ }
+ } else {
+ device.aSecCfg.responsePacketJunkSize = tempASecCfg.responsePacketJunkSize
+ }
+
+ if tempASecCfg.responsePacketJunkSize != 0 {
+ isASecOn = true
+ }
+
+ if tempASecCfg.initPacketMagicHeader > 4 {
+ isASecOn = true
+ device.log.Verbosef("UAPI: Updating init_packet_magic_header")
+ device.aSecCfg.initPacketMagicHeader = tempASecCfg.initPacketMagicHeader
+ MessageInitiationType = device.aSecCfg.initPacketMagicHeader
+ } else {
+ device.log.Verbosef("UAPI: Using default init type")
+ MessageInitiationType = 1
+ }
+
+ if tempASecCfg.responsePacketMagicHeader > 4 {
+ isASecOn = true
+ device.log.Verbosef("UAPI: Updating response_packet_magic_header")
+ device.aSecCfg.responsePacketMagicHeader = tempASecCfg.responsePacketMagicHeader
+ MessageResponseType = device.aSecCfg.responsePacketMagicHeader
+ } else {
+ device.log.Verbosef("UAPI: Using default response type")
+ MessageResponseType = 2
+ }
+
+ if tempASecCfg.underloadPacketMagicHeader > 4 {
+ isASecOn = true
+ device.log.Verbosef("UAPI: Updating underload_packet_magic_header")
+ device.aSecCfg.underloadPacketMagicHeader = tempASecCfg.underloadPacketMagicHeader
+ MessageCookieReplyType = device.aSecCfg.underloadPacketMagicHeader
+ } else {
+ device.log.Verbosef("UAPI: Using default underload type")
+ MessageCookieReplyType = 3
+ }
+
+ if tempASecCfg.transportPacketMagicHeader > 4 {
+ isASecOn = true
+ device.log.Verbosef("UAPI: Updating transport_packet_magic_header")
+ device.aSecCfg.transportPacketMagicHeader = tempASecCfg.transportPacketMagicHeader
+ MessageTransportType = device.aSecCfg.transportPacketMagicHeader
+ } else {
+ device.log.Verbosef("UAPI: Using default transport type")
+ MessageTransportType = 4
+ }
+
+ isSameMap := map[uint32]bool{}
+ isSameMap[MessageInitiationType] = true
+ isSameMap[MessageResponseType] = true
+ isSameMap[MessageCookieReplyType] = true
+ isSameMap[MessageTransportType] = true
+
+ // size will be different if same values
+ if len(isSameMap) != 4 {
+ if err != nil {
+ err = ipcErrorf(
+ ipc.IpcErrorInvalid,
+ `magic headers should differ; got: init:%d; recv:%d; unde:%d; tran:%d; %w`,
+ MessageInitiationType,
+ MessageResponseType,
+ MessageCookieReplyType,
+ MessageTransportType,
+ err,
+ )
+ } else {
+ err = ipcErrorf(
+ ipc.IpcErrorInvalid,
+ `magic headers should differ; got: init:%d; recv:%d; unde:%d; tran:%d`,
+ MessageInitiationType,
+ MessageResponseType,
+ MessageCookieReplyType,
+ MessageTransportType,
+ )
+ }
+ }
+
+ newInitSize := MessageInitiationSize + device.aSecCfg.initPacketJunkSize
+ newResponseSize := MessageResponseSize + device.aSecCfg.responsePacketJunkSize
+
+ if newInitSize == newResponseSize {
+ if err != nil {
+ err = ipcErrorf(
+ ipc.IpcErrorInvalid,
+ `new init size:%d; and new response size:%d; should differ; %w`,
+ newInitSize,
+ newResponseSize,
+ err,
+ )
+ } else {
+ err = ipcErrorf(
+ ipc.IpcErrorInvalid,
+ `new init size:%d; and new response size:%d; should differ`,
+ newInitSize,
+ newResponseSize,
+ )
+ }
+ } else {
+ packetSizeToMsgType = map[int]uint32{
+ newInitSize: MessageInitiationType,
+ newResponseSize: MessageResponseType,
+ MessageCookieReplySize: MessageCookieReplyType,
+ MessageTransportSize: MessageTransportType,
+ }
+
+ msgTypeToJunkSize = map[uint32]int{
+ MessageInitiationType: device.aSecCfg.initPacketJunkSize,
+ MessageResponseType: device.aSecCfg.responsePacketJunkSize,
+ MessageCookieReplyType: 0,
+ MessageTransportType: 0,
+ }
+ }
+
+ device.isASecOn.SetTo(isASecOn)
+ device.aSecMux.Unlock()
+
+ return err
+}
diff --git a/device/device_test.go b/device/device_test.go
index fff172b..afa1dc3 100644
--- a/device/device_test.go
+++ b/device/device_test.go
@@ -20,10 +20,10 @@ import (
"testing"
"time"
- "golang.zx2c4.com/wireguard/conn"
- "golang.zx2c4.com/wireguard/conn/bindtest"
- "golang.zx2c4.com/wireguard/tun"
- "golang.zx2c4.com/wireguard/tun/tuntest"
+ "github.com/amnezia-vpn/amnezia-wg/conn"
+ "github.com/amnezia-vpn/amnezia-wg/conn/bindtest"
+ "github.com/amnezia-vpn/amnezia-wg/tun"
+ "github.com/amnezia-vpn/amnezia-wg/tun/tuntest"
)
// uapiCfg returns a string that contains cfg formatted use with IpcSet.
@@ -91,6 +91,65 @@ func genConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) {
return
}
+func genASecurityConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) {
+ var key1, key2 NoisePrivateKey
+ _, err := rand.Read(key1[:])
+ if err != nil {
+ tb.Errorf("unable to generate private key random bytes: %v", err)
+ }
+ _, err = rand.Read(key2[:])
+ if err != nil {
+ tb.Errorf("unable to generate private key random bytes: %v", err)
+ }
+ pub1, pub2 := key1.publicKey(), key2.publicKey()
+
+ cfgs[0] = uapiCfg(
+ "private_key", hex.EncodeToString(key1[:]),
+ "listen_port", "0",
+ "replace_peers", "true",
+ "jc", "5",
+ "jmin", "500",
+ "jmax", "501",
+ "s1", "30",
+ "s2", "40",
+ "h1", "123456",
+ "h2", "67543",
+ "h4", "32345",
+ "h3", "123123",
+ "public_key", hex.EncodeToString(pub2[:]),
+ "protocol_version", "1",
+ "replace_allowed_ips", "true",
+ "allowed_ip", "1.0.0.2/32",
+ )
+ endpointCfgs[0] = uapiCfg(
+ "public_key", hex.EncodeToString(pub2[:]),
+ "endpoint", "127.0.0.1:%d",
+ )
+ cfgs[1] = uapiCfg(
+ "private_key", hex.EncodeToString(key2[:]),
+ "listen_port", "0",
+ "replace_peers", "true",
+ "jc", "5",
+ "jmin", "500",
+ "jmax", "501",
+ "s1", "30",
+ "s2", "40",
+ "h1", "123456",
+ "h2", "67543",
+ "h4", "32345",
+ "h3", "123123",
+ "public_key", hex.EncodeToString(pub1[:]),
+ "protocol_version", "1",
+ "replace_allowed_ips", "true",
+ "allowed_ip", "1.0.0.1/32",
+ )
+ endpointCfgs[1] = uapiCfg(
+ "public_key", hex.EncodeToString(pub1[:]),
+ "endpoint", "127.0.0.1:%d",
+ )
+ return
+}
+
// A testPair is a pair of testPeers.
type testPair [2]testPeer
@@ -115,7 +174,11 @@ func (d SendDirection) String() string {
return "pong"
}
-func (pair *testPair) Send(tb testing.TB, ping SendDirection, done chan struct{}) {
+func (pair *testPair) Send(
+ tb testing.TB,
+ ping SendDirection,
+ done chan struct{},
+) {
tb.Helper()
p0, p1 := pair[0], pair[1]
if !ping {
@@ -149,8 +212,16 @@ func (pair *testPair) Send(tb testing.TB, ping SendDirection, done chan struct{}
}
// genTestPair creates a testPair.
-func genTestPair(tb testing.TB, realSocket bool) (pair testPair) {
- cfg, endpointCfg := genConfigs(tb)
+func genTestPair(
+ tb testing.TB,
+ realSocket, withASecurity bool,
+) (pair testPair) {
+ var cfg, endpointCfg [2]string
+ if withASecurity {
+ cfg, endpointCfg = genASecurityConfigs(tb)
+ } else {
+ cfg, endpointCfg = genConfigs(tb)
+ }
var binds [2]conn.Bind
if realSocket {
binds[0], binds[1] = conn.NewDefaultBind(), conn.NewDefaultBind()
@@ -166,7 +237,7 @@ func genTestPair(tb testing.TB, realSocket bool) (pair testPair) {
if _, ok := tb.(*testing.B); ok && !testing.Verbose() {
level = LogLevelError
}
- p.dev = NewDevice(p.tun.TUN(), binds[i], NewLogger(level, fmt.Sprintf("dev%d: ", i)))
+ p.dev = NewDevice(p.tun.TUN(),binds[i],NewLogger(level, fmt.Sprintf("dev%d: ", i)))
if err := p.dev.IpcSet(cfg[i]); err != nil {
tb.Errorf("failed to configure device %d: %v", i, err)
p.dev.Close()
@@ -194,7 +265,18 @@ func genTestPair(tb testing.TB, realSocket bool) (pair testPair) {
func TestTwoDevicePing(t *testing.T) {
goroutineLeakCheck(t)
- pair := genTestPair(t, true)
+ pair := genTestPair(t, true, false)
+ t.Run("ping 1.0.0.1", func(t *testing.T) {
+ pair.Send(t, Ping, nil)
+ })
+ t.Run("ping 1.0.0.2", func(t *testing.T) {
+ pair.Send(t, Pong, nil)
+ })
+}
+
+func TestTwoDevicePingASecurity(t *testing.T) {
+ goroutineLeakCheck(t)
+ pair := genTestPair(t, true, true)
t.Run("ping 1.0.0.1", func(t *testing.T) {
pair.Send(t, Ping, nil)
})
@@ -209,10 +291,10 @@ func TestUpDown(t *testing.T) {
const otrials = 10
for n := 0; n < otrials; n++ {
- pair := genTestPair(t, false)
+ pair := genTestPair(t, false, false)
for i := range pair {
for k := range pair[i].dev.peers.keyMap {
- pair[i].dev.IpcSet(fmt.Sprintf("public_key=%s\npersistent_keepalive_interval=1\n", hex.EncodeToString(k[:])))
+ pair[i].dev.IpcSet(fmt.Sprintf("public_key=%s\npersistent_keepalive_interval=1\n",hex.EncodeToString(k[:])))
}
}
var wg sync.WaitGroup
@@ -243,7 +325,7 @@ func TestUpDown(t *testing.T) {
// TestConcurrencySafety does other things concurrently with tunnel use.
// It is intended to be used with the race detector to catch data races.
func TestConcurrencySafety(t *testing.T) {
- pair := genTestPair(t, true)
+ pair := genTestPair(t, true, false)
done := make(chan struct{})
const warmupIters = 10
@@ -324,7 +406,7 @@ func TestConcurrencySafety(t *testing.T) {
}
func BenchmarkLatency(b *testing.B) {
- pair := genTestPair(b, true)
+ pair := genTestPair(b, true, false)
// Establish a connection.
pair.Send(b, Ping, nil)
@@ -338,7 +420,7 @@ func BenchmarkLatency(b *testing.B) {
}
func BenchmarkThroughput(b *testing.B) {
- pair := genTestPair(b, true)
+ pair := genTestPair(b, true, false)
// Establish a connection.
pair.Send(b, Ping, nil)
@@ -382,7 +464,7 @@ func BenchmarkThroughput(b *testing.B) {
}
func BenchmarkUAPIGet(b *testing.B) {
- pair := genTestPair(b, true)
+ pair := genTestPair(b, true, false)
pair.Send(b, Ping, nil)
pair.Send(b, Pong, nil)
b.ReportAllocs()
@@ -423,29 +505,41 @@ type fakeBindSized struct {
size int
}
-func (b *fakeBindSized) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
+func (b *fakeBindSized) Open(
+ port uint16,
+) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
return nil, 0, nil
}
-func (b *fakeBindSized) Close() error { return nil }
-func (b *fakeBindSized) SetMark(mark uint32) error { return nil }
-func (b *fakeBindSized) Send(bufs [][]byte, ep conn.Endpoint) error { return nil }
+
+func (b *fakeBindSized) Close() error { return nil }
+
+func (b *fakeBindSized) SetMark(mark uint32) error {return nil }
+
+func (b *fakeBindSized) Send(bufs [][]byte, ep conn.Endpoint) error { return nil }
+
func (b *fakeBindSized) ParseEndpoint(s string) (conn.Endpoint, error) { return nil, nil }
-func (b *fakeBindSized) BatchSize() int { return b.size }
+
+func (b *fakeBindSized) BatchSize() int { return b.size }
type fakeTUNDeviceSized struct {
size int
}
func (t *fakeTUNDeviceSized) File() *os.File { return nil }
-func (t *fakeTUNDeviceSized) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
- return 0, nil
-}
+
+func (t *fakeTUNDeviceSized) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) { return 0, nil }
+
func (t *fakeTUNDeviceSized) Write(bufs [][]byte, offset int) (int, error) { return 0, nil }
-func (t *fakeTUNDeviceSized) MTU() (int, error) { return 0, nil }
-func (t *fakeTUNDeviceSized) Name() (string, error) { return "", nil }
-func (t *fakeTUNDeviceSized) Events() <-chan tun.Event { return nil }
-func (t *fakeTUNDeviceSized) Close() error { return nil }
-func (t *fakeTUNDeviceSized) BatchSize() int { return t.size }
+
+func (t *fakeTUNDeviceSized) MTU() (int, error) { return 0, nil }
+
+func (t *fakeTUNDeviceSized) Name() (string, error) { return "", nil }
+
+func (t *fakeTUNDeviceSized) Events() <-chan tun.Event { return nil }
+
+func (t *fakeTUNDeviceSized) Close() error { return nil }
+
+func (t *fakeTUNDeviceSized) BatchSize() int { return t.size }
func TestBatchSize(t *testing.T) {
d := Device{}
diff --git a/device/keypair.go b/device/keypair.go
index e3540d7..73e69af 100644
--- a/device/keypair.go
+++ b/device/keypair.go
@@ -11,7 +11,7 @@ import (
"sync/atomic"
"time"
- "golang.zx2c4.com/wireguard/replay"
+ "github.com/amnezia-vpn/amnezia-wg/replay"
)
/* Due to limitations in Go and /x/crypto there is currently
diff --git a/device/noise-protocol.go b/device/noise-protocol.go
index e8f6145..75c1d87 100644
--- a/device/noise-protocol.go
+++ b/device/noise-protocol.go
@@ -15,7 +15,7 @@ import (
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/crypto/poly1305"
- "golang.zx2c4.com/wireguard/tai64n"
+ "github.com/amnezia-vpn/amnezia-wg/tai64n"
)
type handshakeState int
@@ -52,11 +52,11 @@ const (
WGLabelCookie = "cookie--"
)
-const (
- MessageInitiationType = 1
- MessageResponseType = 2
- MessageCookieReplyType = 3
- MessageTransportType = 4
+var (
+ MessageInitiationType uint32 = 1
+ MessageResponseType uint32 = 2
+ MessageCookieReplyType uint32 = 3
+ MessageTransportType uint32 = 4
)
const (
@@ -75,6 +75,10 @@ const (
MessageTransportOffsetContent = 16
)
+var packetSizeToMsgType map[int]uint32
+
+var msgTypeToJunkSize map[uint32]int
+
/* Type is an 8-bit field, followed by 3 nul bytes,
* by marshalling the messages in little-endian byteorder
* we can treat these as a 32-bit unsigned int (for now)
@@ -193,10 +197,12 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
handshake.mixHash(handshake.remoteStatic[:])
+ device.aSecMux.RLock()
msg := MessageInitiation{
Type: MessageInitiationType,
Ephemeral: handshake.localEphemeral.publicKey(),
}
+ device.aSecMux.RUnlock()
handshake.mixKey(msg.Ephemeral[:])
handshake.mixHash(msg.Ephemeral[:])
@@ -250,9 +256,12 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
chainKey [blake2s.Size]byte
)
+ device.aSecMux.RLock()
if msg.Type != MessageInitiationType {
+ device.aSecMux.RUnlock()
return nil
}
+ device.aSecMux.RUnlock()
device.staticIdentity.RLock()
defer device.staticIdentity.RUnlock()
@@ -367,7 +376,9 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
}
var msg MessageResponse
+ device.aSecMux.RLock()
msg.Type = MessageResponseType
+ device.aSecMux.RUnlock()
msg.Sender = handshake.localIndex
msg.Receiver = handshake.remoteIndex
@@ -417,9 +428,12 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
}
func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
+ device.aSecMux.RLock()
if msg.Type != MessageResponseType {
+ device.aSecMux.RUnlock()
return nil
}
+ device.aSecMux.RUnlock()
// lookup handshake by receiver
diff --git a/device/noise_test.go b/device/noise_test.go
index 2dd5324..2363365 100644
--- a/device/noise_test.go
+++ b/device/noise_test.go
@@ -10,8 +10,8 @@ import (
"encoding/binary"
"testing"
- "golang.zx2c4.com/wireguard/conn"
- "golang.zx2c4.com/wireguard/tun/tuntest"
+ "github.com/amnezia-vpn/amnezia-wg/conn"
+ "github.com/amnezia-vpn/amnezia-wg/tun/tuntest"
)
func TestCurveWrappers(t *testing.T) {
diff --git a/device/peer.go b/device/peer.go
index 0ac4896..72c7d1a 100644
--- a/device/peer.go
+++ b/device/peer.go
@@ -12,7 +12,7 @@ import (
"sync/atomic"
"time"
- "golang.zx2c4.com/wireguard/conn"
+ "github.com/amnezia-vpn/amnezia-wg/conn"
)
type Peer struct {
diff --git a/device/queueconstants_android.go b/device/queueconstants_android.go
index 3d80ead..4adb687 100644
--- a/device/queueconstants_android.go
+++ b/device/queueconstants_android.go
@@ -5,7 +5,7 @@
package device
-import "golang.zx2c4.com/wireguard/conn"
+import "github.com/amnezia-vpn/amnezia-wg/conn"
/* Reduce memory consumption for Android */
diff --git a/device/queueconstants_default.go b/device/queueconstants_default.go
index ea763d0..4ee2966 100644
--- a/device/queueconstants_default.go
+++ b/device/queueconstants_default.go
@@ -7,7 +7,7 @@
package device
-import "golang.zx2c4.com/wireguard/conn"
+import "github.com/amnezia-vpn/amnezia-wg/conn"
const (
QueueStagedSize = conn.IdealBatchSize
diff --git a/device/receive.go b/device/receive.go
index e24d29f..ca71539 100644
--- a/device/receive.go
+++ b/device/receive.go
@@ -13,10 +13,10 @@ import (
"sync"
"time"
+ "github.com/amnezia-vpn/amnezia-wg/conn"
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
- "golang.zx2c4.com/wireguard/conn"
)
type QueueHandshakeElement struct {
@@ -66,7 +66,10 @@ func (peer *Peer) keepKeyFreshReceiving() {
* Every time the bind is updated a new routine is started for
* IPv4 and IPv6 (separately)
*/
-func (device *Device) RoutineReceiveIncoming(maxBatchSize int, recv conn.ReceiveFunc) {
+func (device *Device) RoutineReceiveIncoming(
+ maxBatchSize int,
+ recv conn.ReceiveFunc,
+) {
recvName := recv.PrettyName()
defer func() {
device.log.Verbosef("Routine: receive incoming %s - stopped", recvName)
@@ -122,6 +125,7 @@ func (device *Device) RoutineReceiveIncoming(maxBatchSize int, recv conn.Receive
}
deathSpiral = 0
+ device.aSecMux.RLock()
// handle each packet in the batch
for i, size := range sizes[:count] {
if size < MinMessageSize {
@@ -131,8 +135,29 @@ func (device *Device) RoutineReceiveIncoming(maxBatchSize int, recv conn.Receive
// check size of packet
packet := bufsArrs[i][:size]
- msgType := binary.LittleEndian.Uint32(packet[:4])
-
+ var msgType uint32
+ if device.isAdvancedSecurityOn() {
+ if assumedMsgType, ok := packetSizeToMsgType[size]; ok {
+ junkSize := msgTypeToJunkSize[assumedMsgType]
+ // transport size can align with other header types;
+ // making sure we have the right msgType
+ msgType = binary.LittleEndian.Uint32(packet[junkSize:junkSize+4])
+ if msgType == assumedMsgType {
+ packet = packet[junkSize:]
+ } else {
+ device.log.Verbosef("Transport packet lined up with another msg type")
+ msgType = binary.LittleEndian.Uint32(packet[:4])
+ }
+ } else {
+ msgType = binary.LittleEndian.Uint32(packet[:4])
+ if msgType != MessageTransportType {
+ device.log.Verbosef("ASec: Received message with unknown type")
+ continue
+ }
+ }
+ } else {
+ msgType = binary.LittleEndian.Uint32(packet[:4])
+ }
switch msgType {
// check if transport
@@ -217,6 +242,7 @@ func (device *Device) RoutineReceiveIncoming(maxBatchSize int, recv conn.Receive
default:
}
}
+ device.aSecMux.RUnlock()
for peer, elems := range elemsByPeer {
if peer.isRunning.Load() {
peer.queue.inbound.c <- elems
@@ -275,6 +301,8 @@ func (device *Device) RoutineHandshake(id int) {
for elem := range device.queue.handshake.c {
+ device.aSecMux.RLock()
+
// handle cookie fields and ratelimiting
switch elem.msgType {
@@ -302,9 +330,14 @@ func (device *Device) RoutineHandshake(id int) {
// consume reply
if peer := entry.peer; peer.isRunning.Load() {
- device.log.Verbosef("Receiving cookie response from %s", elem.endpoint.DstToString())
+ device.log.Verbosef(
+ "Receiving cookie response from %s",
+ elem.endpoint.DstToString(),
+ )
if !peer.cookieGenerator.ConsumeReply(&reply) {
- device.log.Verbosef("Could not decrypt invalid cookie response")
+ device.log.Verbosef(
+ "Could not decrypt invalid cookie response",
+ )
}
}
@@ -346,9 +379,7 @@ func (device *Device) RoutineHandshake(id int) {
switch elem.msgType {
case MessageInitiationType:
-
// unmarshal
-
var msg MessageInitiation
reader := bytes.NewReader(elem.packet)
err := binary.Read(reader, binary.LittleEndian, &msg)
@@ -358,7 +389,6 @@ func (device *Device) RoutineHandshake(id int) {
}
// consume initiation
-
peer := device.ConsumeMessageInitiation(&msg)
if peer == nil {
device.log.Verbosef("Received invalid initiation message from %s", elem.endpoint.DstToString())
@@ -423,6 +453,7 @@ func (device *Device) RoutineHandshake(id int) {
peer.SendKeepalive()
}
skip:
+ device.aSecMux.RUnlock()
device.PutMessageBuffer(elem.buffer)
}
}
@@ -503,11 +534,17 @@ func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) {
}
default:
- device.log.Verbosef("Packet with invalid IP version from %v", peer)
+ device.log.Verbosef(
+ "Packet with invalid IP version from %v",
+ peer,
+ )
continue
}
- bufs = append(bufs, elem.buffer[:MessageTransportOffsetContent+len(elem.packet)])
+ bufs = append(
+ bufs,
+ elem.buffer[:MessageTransportOffsetContent+len(elem.packet)],
+ )
}
if len(bufs) > 0 {
_, err := device.tun.device.Write(bufs, MessageTransportOffsetContent)
diff --git a/device/send.go b/device/send.go
index d22bf26..6f70d54 100644
--- a/device/send.go
+++ b/device/send.go
@@ -9,15 +9,16 @@ import (
"bytes"
"encoding/binary"
"errors"
+ "math/rand"
"net"
"os"
"sync"
"time"
+ "github.com/amnezia-vpn/amnezia-wg/tun"
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
- "golang.zx2c4.com/wireguard/tun"
)
/* Outbound flow
@@ -119,17 +120,44 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
peer.device.log.Errorf("%v - Failed to create initiation message: %v", peer, err)
return err
}
-
+ var sendBuffer [][]byte
+ // so only packet processed for cookie generation
+ var junkedHeader []byte
+ if peer.device.isAdvancedSecurityOn() {
+ peer.device.aSecMux.RLock()
+ junks, err := peer.createJunkPackets()
+ if err != nil {
+ peer.device.aSecMux.RUnlock()
+ peer.device.log.Errorf("%v - %v", peer, err)
+ return err
+ }
+ sendBuffer = append(sendBuffer, junks...)
+ if peer.device.aSecCfg.initPacketJunkSize != 0 {
+ buf := make([]byte, 0, peer.device.aSecCfg.initPacketJunkSize)
+ writer := bytes.NewBuffer(buf[:0])
+ err = appendJunk(writer, peer.device.aSecCfg.initPacketJunkSize)
+ if err != nil {
+ peer.device.aSecMux.RUnlock()
+ peer.device.log.Errorf("%v - %v", peer, err)
+ return err
+ }
+ junkedHeader = writer.Bytes()
+ }
+ peer.device.aSecMux.RUnlock()
+ }
var buf [MessageInitiationSize]byte
writer := bytes.NewBuffer(buf[:0])
binary.Write(writer, binary.LittleEndian, msg)
packet := writer.Bytes()
peer.cookieGenerator.AddMacs(packet)
+ junkedHeader = append(junkedHeader, packet...)
peer.timersAnyAuthenticatedPacketTraversal()
peer.timersAnyAuthenticatedPacketSent()
-
- err = peer.SendBuffers([][]byte{packet})
+
+ sendBuffer = append(sendBuffer, junkedHeader)
+
+ err = peer.SendBuffers(sendBuffer)
if err != nil {
peer.device.log.Errorf("%v - Failed to send handshake initiation: %v", peer, err)
}
@@ -150,12 +178,29 @@ func (peer *Peer) SendHandshakeResponse() error {
peer.device.log.Errorf("%v - Failed to create response message: %v", peer, err)
return err
}
-
+ var junkedHeader []byte
+ if peer.device.isAdvancedSecurityOn() {
+ peer.device.aSecMux.RLock()
+ if peer.device.aSecCfg.responsePacketJunkSize != 0 {
+ buf := make([]byte, 0, peer.device.aSecCfg.responsePacketJunkSize)
+ writer := bytes.NewBuffer(buf[:0])
+ err = appendJunk(writer, peer.device.aSecCfg.responsePacketJunkSize)
+ if err != nil {
+ peer.device.aSecMux.RUnlock()
+ peer.device.log.Errorf("%v - %v", peer, err)
+ return err
+ }
+ junkedHeader = writer.Bytes()
+ }
+ peer.device.aSecMux.RUnlock()
+ }
var buf [MessageResponseSize]byte
writer := bytes.NewBuffer(buf[:0])
+
binary.Write(writer, binary.LittleEndian, response)
packet := writer.Bytes()
peer.cookieGenerator.AddMacs(packet)
+ junkedHeader = append(junkedHeader, packet...)
err = peer.BeginSymmetricSession()
if err != nil {
@@ -168,18 +213,24 @@ func (peer *Peer) SendHandshakeResponse() error {
peer.timersAnyAuthenticatedPacketSent()
// TODO: allocation could be avoided
- err = peer.SendBuffers([][]byte{packet})
+ err = peer.SendBuffers([][]byte{junkedHeader})
if err != nil {
peer.device.log.Errorf("%v - Failed to send handshake response: %v", peer, err)
}
return err
}
-func (device *Device) SendHandshakeCookie(initiatingElem *QueueHandshakeElement) error {
+func (device *Device) SendHandshakeCookie(
+ initiatingElem *QueueHandshakeElement,
+) error {
device.log.Verbosef("Sending cookie response for denied handshake message for %v", initiatingElem.endpoint.DstToString())
sender := binary.LittleEndian.Uint32(initiatingElem.packet[4:8])
- reply, err := device.cookieChecker.CreateReply(initiatingElem.packet, sender, initiatingElem.endpoint.DstToBytes())
+ reply, err := device.cookieChecker.CreateReply(
+ initiatingElem.packet,
+ sender,
+ initiatingElem.endpoint.DstToBytes(),
+ )
if err != nil {
device.log.Errorf("Failed to create cookie reply: %v", err)
return err
@@ -404,6 +455,31 @@ top:
}
}
+func (peer *Peer) createJunkPackets() ([][]byte, error) {
+ if peer.device.aSecCfg.junkPacketCount == 0 {
+ return nil, nil
+ }
+
+ junks := make([][]byte, 0, peer.device.aSecCfg.junkPacketCount)
+ for i := 0; i < peer.device.aSecCfg.junkPacketCount; i++ {
+ packetSize := rand.Intn(
+ peer.device.aSecCfg.junkPacketMaxSize-peer.device.aSecCfg.junkPacketMinSize,
+ ) + peer.device.aSecCfg.junkPacketMinSize
+
+ junk, err := randomJunkWithSize(packetSize)
+ if err != nil {
+ peer.device.log.Errorf(
+ "%v - Failed to create junk packet: %v",
+ peer,
+ err,
+ )
+ return nil, err
+ }
+ junks = append(junks, junk)
+ }
+ return junks, nil
+}
+
func (peer *Peer) FlushStagedPackets() {
for {
select {
@@ -459,18 +535,16 @@ func (device *Device) RoutineEncryption(id int) {
binary.LittleEndian.PutUint64(fieldNonce, elem.nonce)
// pad content to multiple of 16
- paddingSize := calculatePaddingSize(len(elem.packet), int(device.tun.mtu.Load()))
+ paddingSize := calculatePaddingSize(
+ len(elem.packet),
+ int(device.tun.mtu.Load()),
+ )
elem.packet = append(elem.packet, paddingZeros[:paddingSize]...)
// encrypt content and release to consumer
binary.LittleEndian.PutUint64(nonce[4:], elem.nonce)
- elem.packet = elem.keypair.send.Seal(
- header,
- nonce[:],
- elem.packet,
- nil,
- )
+ elem.packet = elem.keypair.send.Seal(header, nonce[:], elem.packet, nil)
elem.Unlock()
}
}
diff --git a/device/sticky_default.go b/device/sticky_default.go
index 1038256..940702c 100644
--- a/device/sticky_default.go
+++ b/device/sticky_default.go
@@ -3,8 +3,8 @@
package device
import (
- "golang.zx2c4.com/wireguard/conn"
- "golang.zx2c4.com/wireguard/rwcancel"
+ "github.com/amnezia-vpn/amnezia-wg/conn"
+ "github.com/amnezia-vpn/amnezia-wg/rwcancel"
)
func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) {
diff --git a/device/sticky_linux.go b/device/sticky_linux.go
index f9230f8..5c17480 100644
--- a/device/sticky_linux.go
+++ b/device/sticky_linux.go
@@ -20,8 +20,8 @@ import (
"golang.org/x/sys/unix"
- "golang.zx2c4.com/wireguard/conn"
- "golang.zx2c4.com/wireguard/rwcancel"
+ "github.com/amnezia-vpn/amnezia-wg/conn"
+ "github.com/amnezia-vpn/amnezia-wg/rwcancel"
)
func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) {
diff --git a/device/tun.go b/device/tun.go
index 2a2ace9..efc543d 100644
--- a/device/tun.go
+++ b/device/tun.go
@@ -8,7 +8,7 @@ package device
import (
"fmt"
- "golang.zx2c4.com/wireguard/tun"
+ "github.com/amnezia-vpn/amnezia-wg/tun"
)
const DefaultMTU = 1420
diff --git a/device/uapi.go b/device/uapi.go
index 617dcd3..bfd005a 100644
--- a/device/uapi.go
+++ b/device/uapi.go
@@ -18,7 +18,7 @@ import (
"sync"
"time"
- "golang.zx2c4.com/wireguard/ipc"
+ "github.com/amnezia-vpn/amnezia-wg/ipc"
)
type IPCError struct {
@@ -97,6 +97,36 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
sendf("fwmark=%d", device.net.fwmark)
}
+ if device.isAdvancedSecurityOn() {
+ if device.aSecCfg.junkPacketCount != 0 {
+ sendf("jc=%d", device.aSecCfg.junkPacketCount)
+ }
+ if device.aSecCfg.junkPacketMinSize != 0 {
+ sendf("jmin=%d", device.aSecCfg.junkPacketMinSize)
+ }
+ if device.aSecCfg.junkPacketMaxSize != 0 {
+ sendf("jmax=%d", device.aSecCfg.junkPacketMaxSize)
+ }
+ if device.aSecCfg.initPacketJunkSize != 0 {
+ sendf("s1=%d", device.aSecCfg.initPacketJunkSize)
+ }
+ if device.aSecCfg.responsePacketJunkSize != 0 {
+ sendf("s2=%d", device.aSecCfg.responsePacketJunkSize)
+ }
+ if device.aSecCfg.initPacketMagicHeader != 0 {
+ sendf("h1=%d", device.aSecCfg.initPacketMagicHeader)
+ }
+ if device.aSecCfg.responsePacketMagicHeader != 0 {
+ sendf("h2=%d", device.aSecCfg.responsePacketMagicHeader)
+ }
+ if device.aSecCfg.underloadPacketMagicHeader != 0 {
+ sendf("h3=%d", device.aSecCfg.underloadPacketMagicHeader)
+ }
+ if device.aSecCfg.transportPacketMagicHeader != 0 {
+ sendf("h4=%d", device.aSecCfg.transportPacketMagicHeader)
+ }
+ }
+
for _, peer := range device.peers.keyMap {
// Serialize peer state.
// Do the work in an anonymous function so that we can use defer.
@@ -121,10 +151,13 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
sendf("rx_bytes=%d", peer.rxBytes.Load())
sendf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval.Load())
- device.allowedips.EntriesForPeer(peer, func(prefix netip.Prefix) bool {
- sendf("allowed_ip=%s", prefix.String())
- return true
- })
+ device.allowedips.EntriesForPeer(
+ peer,
+ func(prefix netip.Prefix) bool {
+ sendf("allowed_ip=%s", prefix.String())
+ return true
+ },
+ )
}()
}
}()
@@ -152,17 +185,26 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
peer := new(ipcSetPeer)
deviceConfig := true
+ tempASecCfg := aSecCfgType{}
scanner := bufio.NewScanner(r)
for scanner.Scan() {
line := scanner.Text()
if line == "" {
// Blank line means terminate operation.
+ err := device.handlePostConfig(&tempASecCfg)
+ if err != nil {
+ return err
+ }
peer.handlePostConfig()
return nil
}
key, value, ok := strings.Cut(line, "=")
if !ok {
- return ipcErrorf(ipc.IpcErrorProtocol, "failed to parse line %q", line)
+ return ipcErrorf(
+ ipc.IpcErrorProtocol,
+ "failed to parse line %q",
+ line,
+ )
}
if key == "public_key" {
@@ -180,7 +222,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
var err error
if deviceConfig {
- err = device.handleDeviceLine(key, value)
+ err = device.handleDeviceLine(key, value, &tempASecCfg)
} else {
err = device.handlePeerLine(peer, key, value)
}
@@ -188,6 +230,10 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
return err
}
}
+ err = device.handlePostConfig(&tempASecCfg)
+ if err != nil {
+ return err
+ }
peer.handlePostConfig()
if err := scanner.Err(); err != nil {
@@ -196,7 +242,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
return nil
}
-func (device *Device) handleDeviceLine(key, value string) error {
+func (device *Device) handleDeviceLine(key, value string, tempASecCfg *aSecCfgType) error {
switch key {
case "private_key":
var sk NoisePrivateKey
@@ -242,8 +288,75 @@ func (device *Device) handleDeviceLine(key, value string) error {
device.log.Verbosef("UAPI: Removing all peers")
device.RemoveAllPeers()
+ case "jc":
+ junkPacketCount, err := strconv.Atoi(value)
+ if err != nil {
+ return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse junk_packet_count %w", err)
+ }
+ device.log.Verbosef("UAPI: Updating junk_packet_count")
+ tempASecCfg.junkPacketCount = junkPacketCount
+
+ case "jmin":
+ junkPacketMinSize, err := strconv.Atoi(value)
+ if err != nil {
+ return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse junk_packet_min_size %w", err)
+ }
+ device.log.Verbosef("UAPI: Updating junk_packet_min_size")
+ tempASecCfg.junkPacketMinSize = junkPacketMinSize
+
+ case "jmax":
+ junkPacketMaxSize, err := strconv.Atoi(value)
+ if err != nil {
+ return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse junk_packet_max_size %w", err)
+ }
+ device.log.Verbosef("UAPI: Updating junk_packet_max_size")
+ tempASecCfg.junkPacketMaxSize = junkPacketMaxSize
+
+ case "s1":
+ initPacketJunkSize, err := strconv.Atoi(value)
+ if err != nil {
+ return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse init_packet_junk_size %w", err)
+ }
+ device.log.Verbosef("UAPI: Updating init_packet_junk_size")
+ tempASecCfg.initPacketJunkSize = initPacketJunkSize
+
+ case "s2":
+ responsePacketJunkSize, err := strconv.Atoi(value)
+ if err != nil {
+ return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse response_packet_junk_size %w", err)
+ }
+ device.log.Verbosef("UAPI: Updating response_packet_junk_size")
+ tempASecCfg.responsePacketJunkSize = responsePacketJunkSize
+
+ case "h1":
+ initPacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
+ if err != nil {
+ return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse init_packet_magic_header %w", err)
+ }
+ tempASecCfg.initPacketMagicHeader = uint32(initPacketMagicHeader)
+
+ case "h2":
+ responsePacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
+ if err != nil {
+ return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse response_packet_magic_header %w", err)
+ }
+ tempASecCfg.responsePacketMagicHeader = uint32(responsePacketMagicHeader)
+
+ case "h3":
+ underloadPacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
+ if err != nil {
+ return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse underload_packet_magic_header %w", err)
+ }
+ tempASecCfg.underloadPacketMagicHeader = uint32(underloadPacketMagicHeader)
+
+ case "h4":
+ transportPacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
+ if err != nil {
+ return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse transport_packet_magic_header %w", err)
+ }
+ tempASecCfg.transportPacketMagicHeader = uint32(transportPacketMagicHeader)
default:
- return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI device key: %v", key)
+ return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI device key: %v",key)
}
return nil
@@ -262,7 +375,8 @@ func (peer *ipcSetPeer) handlePostConfig() {
return
}
if peer.created {
- peer.disableRoaming = peer.device.net.brokenRoaming && peer.endpoint != nil
+ peer.disableRoaming = peer.device.net.brokenRoaming &&
+ peer.endpoint != nil
}
if peer.device.isUp() {
peer.Start()
@@ -273,7 +387,10 @@ func (peer *ipcSetPeer) handlePostConfig() {
}
}
-func (device *Device) handlePublicKeyLine(peer *ipcSetPeer, value string) error {
+func (device *Device) handlePublicKeyLine(
+ peer *ipcSetPeer,
+ value string,
+) error {
// Load/create the peer we are configuring.
var publicKey NoisePublicKey
err := publicKey.FromHex(value)
@@ -303,7 +420,10 @@ func (device *Device) handlePublicKeyLine(peer *ipcSetPeer, value string) error
return nil
}
-func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error {
+func (device *Device) handlePeerLine(
+ peer *ipcSetPeer,
+ key, value string,
+) error {
switch key {
case "update_only":
// allow disabling of creation
@@ -343,7 +463,7 @@ func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error
device.log.Verbosef("%v - UAPI: Updating endpoint", peer.Peer)
endpoint, err := device.net.bind.ParseEndpoint(value)
if err != nil {
- return ipcErrorf(ipc.IpcErrorInvalid, "failed to set endpoint %v: %w", value, err)
+ return ipcErrorf(ipc.IpcErrorInvalid, "failed to set endpoint %v: %w", value, err)
}
peer.Lock()
defer peer.Unlock()
diff --git a/device/util.go b/device/util.go
new file mode 100644
index 0000000..aab8ab7
--- /dev/null
+++ b/device/util.go
@@ -0,0 +1,25 @@
+package device
+
+import (
+ "bytes"
+ crand "crypto/rand"
+ "fmt"
+)
+
+func appendJunk(writer *bytes.Buffer, size int) error {
+ headerJunk, err := randomJunkWithSize(size)
+ if err != nil {
+ return fmt.Errorf("failed to create header junk: %v", err)
+ }
+ _, err = writer.Write(headerJunk)
+ if err != nil {
+ return fmt.Errorf("failed to write header junk: %v", err)
+ }
+ return nil
+}
+
+func randomJunkWithSize(size int) ([]byte, error) {
+ junk := make([]byte, size)
+ _, err := crand.Read(junk)
+ return junk, err
+}
diff --git a/device/util_test.go b/device/util_test.go
new file mode 100644
index 0000000..c061eef
--- /dev/null
+++ b/device/util_test.go
@@ -0,0 +1,27 @@
+package device
+
+import (
+ "bytes"
+ "fmt"
+ "testing"
+)
+
+func Test_randomJunktWithSize(t *testing.T) {
+ junk, err := randomJunkWithSize(30)
+ fmt.Println(string(junk), len(junk), err)
+}
+
+func Test_appendJunk(t *testing.T) {
+ t.Run("", func(t *testing.T) {
+ s := "apple"
+ buffer := bytes.NewBuffer([]byte(s))
+ err := appendJunk(buffer, 30)
+ if err != nil &&
+ buffer.Len() != len(s)+30 {
+ t.Errorf("appendWithJunk() size don't match")
+ }
+ read := make([]byte, 50)
+ buffer.Read(read)
+ fmt.Println(string(read))
+ })
+}
diff --git a/go.mod b/go.mod
index c04e1bb..4a3c9c6 100644
--- a/go.mod
+++ b/go.mod
@@ -1,8 +1,9 @@
-module golang.zx2c4.com/wireguard
+module github.com/amnezia-vpn/amnezia-wg
go 1.20
require (
+ github.com/tevino/abool/v2 v2.1.0
golang.org/x/crypto v0.6.0
golang.org/x/net v0.7.0
golang.org/x/sys v0.5.1-0.20230222185716-a3b23cc77e89
diff --git a/go.sum b/go.sum
index cfeaee6..3707808 100644
--- a/go.sum
+++ b/go.sum
@@ -1,5 +1,7 @@
github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4=
github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA=
+github.com/tevino/abool/v2 v2.1.0 h1:7w+Vf9f/5gmKT4m4qkayb33/92M+Um45F2BkHOR+L/c=
+github.com/tevino/abool/v2 v2.1.0/go.mod h1:+Lmlqk6bHDWHqN1cbxqhwEAwMPXgc8I1SDEamtseuXY=
golang.org/x/crypto v0.6.0 h1:qfktjS5LUO+fFKeJXZ+ikTRijMmljikvG68fpMMruSc=
golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58=
golang.org/x/net v0.7.0 h1:rJrUqqhjsgNp7KqAIc25s9pZnjU7TUcSY7HcVZjdn1g=
diff --git a/ipc/namedpipe/namedpipe_test.go b/ipc/namedpipe/namedpipe_test.go
index 998453b..d4799e1 100644
--- a/ipc/namedpipe/namedpipe_test.go
+++ b/ipc/namedpipe/namedpipe_test.go
@@ -20,8 +20,8 @@ import (
"testing"
"time"
+ "github.com/amnezia-vpn/amnezia-wg/ipc/namedpipe"
"golang.org/x/sys/windows"
- "golang.zx2c4.com/wireguard/ipc/namedpipe"
)
func randomPipePath() string {
diff --git a/ipc/uapi_linux.go b/ipc/uapi_linux.go
index 1562a18..721c404 100644
--- a/ipc/uapi_linux.go
+++ b/ipc/uapi_linux.go
@@ -9,8 +9,8 @@ import (
"net"
"os"
+ "github.com/amnezia-vpn/amnezia-wg/rwcancel"
"golang.org/x/sys/unix"
- "golang.zx2c4.com/wireguard/rwcancel"
)
type UAPIListener struct {
diff --git a/ipc/uapi_windows.go b/ipc/uapi_windows.go
index aa023c9..97a4123 100644
--- a/ipc/uapi_windows.go
+++ b/ipc/uapi_windows.go
@@ -8,8 +8,8 @@ package ipc
import (
"net"
+ "github.com/amnezia-vpn/amnezia-wg/ipc/namedpipe"
"golang.org/x/sys/windows"
- "golang.zx2c4.com/wireguard/ipc/namedpipe"
)
// TODO: replace these with actual standard windows error numbers from the win package
diff --git a/main.go b/main.go
index e016116..ea7ef4e 100644
--- a/main.go
+++ b/main.go
@@ -14,11 +14,11 @@ import (
"runtime"
"strconv"
+ "github.com/amnezia-vpn/amnezia-wg/conn"
+ "github.com/amnezia-vpn/amnezia-wg/device"
+ "github.com/amnezia-vpn/amnezia-wg/ipc"
+ "github.com/amnezia-vpn/amnezia-wg/tun"
"golang.org/x/sys/unix"
- "golang.zx2c4.com/wireguard/conn"
- "golang.zx2c4.com/wireguard/device"
- "golang.zx2c4.com/wireguard/ipc"
- "golang.zx2c4.com/wireguard/tun"
)
const (
diff --git a/main_windows.go b/main_windows.go
index a4dc46f..d00b146 100644
--- a/main_windows.go
+++ b/main_windows.go
@@ -12,11 +12,11 @@ import (
"golang.org/x/sys/windows"
- "golang.zx2c4.com/wireguard/conn"
- "golang.zx2c4.com/wireguard/device"
- "golang.zx2c4.com/wireguard/ipc"
+ "github.com/amnezia-vpn/amnezia-wg/conn"
+ "github.com/amnezia-vpn/amnezia-wg/device"
+ "github.com/amnezia-vpn/amnezia-wg/ipc"
- "golang.zx2c4.com/wireguard/tun"
+ "github.com/amnezia-vpn/amnezia-wg/tun"
)
const (
diff --git a/tun/netstack/examples/http_client.go b/tun/netstack/examples/http_client.go
index ccd32ed..ed40904 100644
--- a/tun/netstack/examples/http_client.go
+++ b/tun/netstack/examples/http_client.go
@@ -13,9 +13,9 @@ import (
"net/http"
"net/netip"
- "golang.zx2c4.com/wireguard/conn"
- "golang.zx2c4.com/wireguard/device"
- "golang.zx2c4.com/wireguard/tun/netstack"
+ "github.com/amnezia-vpn/amnezia-wg/conn"
+ "github.com/amnezia-vpn/amnezia-wg/device"
+ "github.com/amnezia-vpn/amnezia-wg/tun/netstack"
)
func main() {
diff --git a/tun/netstack/examples/http_server.go b/tun/netstack/examples/http_server.go
index f5b7a8f..d5e7094 100644
--- a/tun/netstack/examples/http_server.go
+++ b/tun/netstack/examples/http_server.go
@@ -14,9 +14,9 @@ import (
"net/http"
"net/netip"
- "golang.zx2c4.com/wireguard/conn"
- "golang.zx2c4.com/wireguard/device"
- "golang.zx2c4.com/wireguard/tun/netstack"
+ "github.com/amnezia-vpn/amnezia-wg/conn"
+ "github.com/amnezia-vpn/amnezia-wg/device"
+ "github.com/amnezia-vpn/amnezia-wg/tun/netstack"
)
func main() {
diff --git a/tun/netstack/examples/ping_client.go b/tun/netstack/examples/ping_client.go
index 2eef0fb..9f917db 100644
--- a/tun/netstack/examples/ping_client.go
+++ b/tun/netstack/examples/ping_client.go
@@ -17,9 +17,9 @@ import (
"golang.org/x/net/icmp"
"golang.org/x/net/ipv4"
- "golang.zx2c4.com/wireguard/conn"
- "golang.zx2c4.com/wireguard/device"
- "golang.zx2c4.com/wireguard/tun/netstack"
+ "github.com/amnezia-vpn/amnezia-wg/conn"
+ "github.com/amnezia-vpn/amnezia-wg/device"
+ "github.com/amnezia-vpn/amnezia-wg/tun/netstack"
)
func main() {
diff --git a/tun/netstack/tun.go b/tun/netstack/tun.go
index 596cfcd..f5a40f5 100644
--- a/tun/netstack/tun.go
+++ b/tun/netstack/tun.go
@@ -22,7 +22,7 @@ import (
"syscall"
"time"
- "golang.zx2c4.com/wireguard/tun"
+ "github.com/amnezia-vpn/amnezia-wg/tun"
"golang.org/x/net/dns/dnsmessage"
"gvisor.dev/gvisor/pkg/bufferv2"
diff --git a/tun/tcp_offload_linux.go b/tun/tcp_offload_linux.go
index 39a7180..a43f0df 100644
--- a/tun/tcp_offload_linux.go
+++ b/tun/tcp_offload_linux.go
@@ -12,8 +12,8 @@ import (
"io"
"unsafe"
+ "github.com/amnezia-vpn/amnezia-wg/conn"
"golang.org/x/sys/unix"
- "golang.zx2c4.com/wireguard/conn"
)
const tcpFlagsOffset = 13
diff --git a/tun/tcp_offload_linux_test.go b/tun/tcp_offload_linux_test.go
index 9160e18..57c6a09 100644
--- a/tun/tcp_offload_linux_test.go
+++ b/tun/tcp_offload_linux_test.go
@@ -9,8 +9,8 @@ import (
"net/netip"
"testing"
+ "github.com/amnezia-vpn/amnezia-wg/conn"
"golang.org/x/sys/unix"
- "golang.zx2c4.com/wireguard/conn"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
)
diff --git a/tun/tun_linux.go b/tun/tun_linux.go
index 12cd49f..31c1513 100644
--- a/tun/tun_linux.go
+++ b/tun/tun_linux.go
@@ -17,9 +17,9 @@ import (
"time"
"unsafe"
+ "github.com/amnezia-vpn/amnezia-wg/conn"
+ "github.com/amnezia-vpn/amnezia-wg/rwcancel"
"golang.org/x/sys/unix"
- "golang.zx2c4.com/wireguard/conn"
- "golang.zx2c4.com/wireguard/rwcancel"
)
const (
diff --git a/tun/tuntest/tuntest.go b/tun/tuntest/tuntest.go
index d07e860..7068d9b 100644
--- a/tun/tuntest/tuntest.go
+++ b/tun/tuntest/tuntest.go
@@ -11,7 +11,7 @@ import (
"net/netip"
"os"
- "golang.zx2c4.com/wireguard/tun"
+ "github.com/amnezia-vpn/amnezia-wg/tun"
)
func Ping(dst, src netip.Addr) []byte {