From 15d7259cd45f41d50bc5838dc002698362f29a2d Mon Sep 17 00:00:00 2001
From: Mark Puha
Date: Tue, 15 Jul 2025 19:05:22 +0200
Subject: [PATCH] feat: ranged magic headers
---
device/awg/awg.go | 160 ++++-------
device/awg/junk_creator.go | 64 ++---
device/awg/junk_creator_test.go | 77 ++---
device/awg/magic_header.go | 93 ++++++
device/awg/magic_header_test.go | 488 ++++++++++++++++++++++++++++++++
device/awg/prng.go | 50 ++++
device/awg/tag_parser.go | 2 +-
device/awg/tag_parser_test.go | 2 +-
device/cookie.go | 3 +-
device/cookie_test.go | 2 +-
device/device.go | 338 ++++++++++++++--------
device/device_test.go | 16 +-
device/noise-protocol.go | 37 ++-
device/receive.go | 48 +---
device/send.go | 31 +-
device/uapi.go | 116 ++++----
go.mod | 2 +-
go.sum | 14 +-
18 files changed, 1084 insertions(+), 459 deletions(-)
create mode 100644 device/awg/magic_header.go
create mode 100644 device/awg/magic_header_test.go
create mode 100644 device/awg/prng.go
diff --git a/device/awg/awg.go b/device/awg/awg.go
index fd5a96d..888a42e 100644
--- a/device/awg/awg.go
+++ b/device/awg/awg.go
@@ -3,142 +3,88 @@ package awg
import (
"bytes"
"fmt"
- "slices"
- "strconv"
- "strings"
"sync"
"github.com/tevino/abool"
)
-type aSecCfgType struct {
- IsSet bool
- JunkPacketCount int
- JunkPacketMinSize int
- JunkPacketMaxSize int
- InitHeaderJunkSize int
- ResponseHeaderJunkSize int
- CookieReplyHeaderJunkSize int
- TransportHeaderJunkSize int
- InitPacketMagicHeader uint32
- ResponsePacketMagicHeader uint32
- UnderloadPacketMagicHeader uint32
- TransportPacketMagicHeader uint32
- // InitPacketMagicHeader Limit
- // ResponsePacketMagicHeader Limit
- // UnderloadPacketMagicHeader Limit
- // TransportPacketMagicHeader Limit
-}
+type Cfg struct {
+ IsSet bool
+ JunkPacketCount int
+ JunkPacketMinSize int
+ JunkPacketMaxSize int
+ InitHeaderJunkSize int
+ ResponseHeaderJunkSize int
+ CookieReplyHeaderJunkSize int
+ TransportHeaderJunkSize int
-type Limit struct {
- Min uint32
- Max uint32
- HeaderType uint32
-}
-
-func NewLimit(min, max, headerType uint32) (Limit, error) {
- if min > max {
- return Limit{}, fmt.Errorf("min (%d) cannot be greater than max (%d)", min, max)
- }
-
- return Limit{
- Min: min,
- Max: max,
- HeaderType: headerType,
- }, nil
-}
-
-func ParseMagicHeader(key, value string, defaultHeaderType uint32) (Limit, error) {
- // tempAwg.ASecCfg.InitPacketMagicHeader, err = awg.NewLimit(uint32(initPacketMagicHeaderMin), uint32(initPacketMagicHeaderMax), DNewLimit(min, max, headerType)efaultMessageInitiationType)
- // var min, max, headerType uint32
- // _, err := fmt.Sscanf(value, "%d-%d:%d", &min, &max, &headerType)
- // if err != nil {
- // return Limit{}, fmt.Errorf("invalid magic header format: %s", value)
- // }
-
- limits := strings.Split(value, "-")
- if len(limits) != 2 {
- return Limit{}, fmt.Errorf("invalid format for key: %s; %s", key, value)
- }
-
- min, err := strconv.ParseUint(limits[0], 10, 32)
- if err != nil {
- return Limit{}, fmt.Errorf("parse min key: %s; value: ; %w", key, limits[0], err)
- }
-
- max, err := strconv.ParseUint(limits[1], 10, 32)
- if err != nil {
- return Limit{}, fmt.Errorf("parse max key: %s; value: ; %w", key, limits[0], err)
- }
-
- limit, err := NewLimit(uint32(min), uint32(max), defaultHeaderType)
- if err != nil {
- return Limit{}, fmt.Errorf("new lmit key: %s; value: ; %w", key, limits[0], err)
- }
-
- return limit, nil
-}
-
-type Limits []Limit
-
-func NewLimits(limits []Limit) Limits {
- slices.SortFunc(limits, func(a, b Limit) int {
- if a.Min < b.Min {
- return -1
- } else if a.Min > b.Min {
- return 1
- }
- return 0
- })
-
- return Limits(limits)
+ MagicHeaders MagicHeaders
}
type Protocol struct {
- IsASecOn abool.AtomicBool
+ IsOn abool.AtomicBool
// TODO: revision the need of the mutex
- ASecMux sync.RWMutex
- ASecCfg aSecCfgType
- JunkCreator junkCreator
+ Mux sync.RWMutex
+ Cfg Cfg
+ JunkCreator JunkCreator
HandshakeHandler SpecialHandshakeHandler
}
func (protocol *Protocol) CreateInitHeaderJunk() ([]byte, error) {
- return protocol.createHeaderJunk(protocol.ASecCfg.InitHeaderJunkSize)
+ protocol.Mux.RLock()
+ defer protocol.Mux.RUnlock()
+
+ return protocol.createHeaderJunk(protocol.Cfg.InitHeaderJunkSize, 0)
}
func (protocol *Protocol) CreateResponseHeaderJunk() ([]byte, error) {
- return protocol.createHeaderJunk(protocol.ASecCfg.ResponseHeaderJunkSize)
+ protocol.Mux.RLock()
+ defer protocol.Mux.RUnlock()
+
+ return protocol.createHeaderJunk(protocol.Cfg.ResponseHeaderJunkSize, 0)
}
func (protocol *Protocol) CreateCookieReplyHeaderJunk() ([]byte, error) {
- return protocol.createHeaderJunk(protocol.ASecCfg.CookieReplyHeaderJunkSize)
+ protocol.Mux.RLock()
+ defer protocol.Mux.RUnlock()
+
+ return protocol.createHeaderJunk(protocol.Cfg.CookieReplyHeaderJunkSize, 0)
}
func (protocol *Protocol) CreateTransportHeaderJunk(packetSize int) ([]byte, error) {
- return protocol.createHeaderJunk(protocol.ASecCfg.TransportHeaderJunkSize, packetSize)
+ protocol.Mux.RLock()
+ defer protocol.Mux.RUnlock()
+
+ return protocol.createHeaderJunk(protocol.Cfg.TransportHeaderJunkSize, packetSize)
}
-func (protocol *Protocol) createHeaderJunk(junkSize int, optExtraSize ...int) ([]byte, error) {
- extraSize := 0
- if len(optExtraSize) == 1 {
- extraSize = optExtraSize[0]
+func (protocol *Protocol) createHeaderJunk(junkSize int, extraSize int) ([]byte, error) {
+ if junkSize == 0 {
+ return nil, nil
}
- var junk []byte
- protocol.ASecMux.RLock()
- if junkSize != 0 {
- buf := make([]byte, 0, junkSize+extraSize)
- writer := bytes.NewBuffer(buf[:0])
- err := protocol.JunkCreator.AppendJunk(writer, junkSize)
- if err != nil {
- protocol.ASecMux.RUnlock()
- return nil, err
+ buf := make([]byte, 0, junkSize+extraSize)
+ writer := bytes.NewBuffer(buf[:0])
+
+ err := protocol.JunkCreator.AppendJunk(writer, junkSize)
+ if err != nil {
+ return nil, fmt.Errorf("append junk: %w", err)
+ }
+
+ return writer.Bytes(), nil
+}
+
+func (protocol *Protocol) GetMagicHeaderMinFor(msgType uint32) (uint32, error) {
+ for _, magicHeader := range protocol.Cfg.MagicHeaders.Values {
+ if magicHeader.Min <= msgType && msgType <= magicHeader.Max {
+ return magicHeader.Min, nil
}
- junk = writer.Bytes()
}
- protocol.ASecMux.RUnlock()
- return junk, nil
+ return 0, fmt.Errorf("no header for value: %d", msgType)
+}
+
+func (protocol *Protocol) GetMsgType(defaultMsgType uint32) (uint32, error) {
+ return protocol.Cfg.MagicHeaders.Get(defaultMsgType)
}
diff --git a/device/awg/junk_creator.go b/device/awg/junk_creator.go
index 91fd253..8ba2918 100644
--- a/device/awg/junk_creator.go
+++ b/device/awg/junk_creator.go
@@ -2,69 +2,49 @@ package awg
import (
"bytes"
- crand "crypto/rand"
"fmt"
- v2 "math/rand/v2"
)
-type junkCreator struct {
- aSecCfg aSecCfgType
- cha8Rand *v2.ChaCha8
+type JunkCreator struct {
+ cfg Cfg
+ randomGenerator PRNG[int]
}
// TODO: refactor param to only pass the junk related params
-func NewJunkCreator(aSecCfg aSecCfgType) (junkCreator, error) {
- buf := make([]byte, 32)
- _, err := crand.Read(buf)
- if err != nil {
- return junkCreator{}, err
- }
- return junkCreator{aSecCfg: aSecCfg, cha8Rand: v2.NewChaCha8([32]byte(buf))}, nil
+func NewJunkCreator(cfg Cfg) JunkCreator {
+ return JunkCreator{cfg: cfg, randomGenerator: NewPRNG[int]()}
}
-// Should be called with aSecMux RLocked
-func (jc *junkCreator) CreateJunkPackets(junks *[][]byte) error {
- if jc.aSecCfg.JunkPacketCount == 0 {
- return nil
+// Should be called with awg mux RLocked
+func (jc *JunkCreator) CreateJunkPackets(junks *[][]byte) {
+ if jc.cfg.JunkPacketCount == 0 {
+ return
}
- for range jc.aSecCfg.JunkPacketCount {
+ for range jc.cfg.JunkPacketCount {
packetSize := jc.randomPacketSize()
- junk, err := jc.randomJunkWithSize(packetSize)
- if err != nil {
- return fmt.Errorf("create junk packet: %v", err)
- }
+ junk := jc.randomJunkWithSize(packetSize)
*junks = append(*junks, junk)
}
- return nil
+ return
}
-// Should be called with aSecMux RLocked
-func (jc *junkCreator) randomPacketSize() int {
- return int(
- jc.cha8Rand.Uint64()%uint64(
- jc.aSecCfg.JunkPacketMaxSize-jc.aSecCfg.JunkPacketMinSize,
- ),
- ) + jc.aSecCfg.JunkPacketMinSize
+// Should be called with awg mux RLocked
+func (jc *JunkCreator) randomPacketSize() int {
+ return jc.randomGenerator.RandomSizeInRange(jc.cfg.JunkPacketMinSize, jc.cfg.JunkPacketMaxSize)
}
-// Should be called with aSecMux RLocked
-func (jc *junkCreator) AppendJunk(writer *bytes.Buffer, size int) error {
- headerJunk, err := jc.randomJunkWithSize(size)
- if err != nil {
- return fmt.Errorf("create header junk: %v", err)
- }
- _, err = writer.Write(headerJunk)
+// Should be called with awg mux RLocked
+func (jc *JunkCreator) AppendJunk(writer *bytes.Buffer, size int) error {
+ headerJunk := jc.randomJunkWithSize(size)
+ _, err := writer.Write(headerJunk)
if err != nil {
return fmt.Errorf("write header junk: %v", err)
}
return nil
}
-// Should be called with aSecMux RLocked
-func (jc *junkCreator) randomJunkWithSize(size int) ([]byte, error) {
- // TODO: use a memory pool to allocate
- junk := make([]byte, size)
- _, err := jc.cha8Rand.Read(junk)
- return junk, err
+// Should be called with awg mux RLocked
+func (jc *JunkCreator) randomJunkWithSize(size int) []byte {
+ return jc.randomGenerator.ReadSize(size)
}
diff --git a/device/awg/junk_creator_test.go b/device/awg/junk_creator_test.go
index 424f104..33532e4 100644
--- a/device/awg/junk_creator_test.go
+++ b/device/awg/junk_creator_test.go
@@ -6,43 +6,29 @@ import (
"testing"
)
-func setUpJunkCreator(t *testing.T) (junkCreator, error) {
- jc, err := NewJunkCreator(aSecCfgType{
- IsSet: true,
- JunkPacketCount: 5,
- JunkPacketMinSize: 500,
- JunkPacketMaxSize: 1000,
- InitHeaderJunkSize: 30,
- ResponseHeaderJunkSize: 40,
- InitPacketMagicHeader: 123456,
- ResponsePacketMagicHeader: 67543,
- UnderloadPacketMagicHeader: 32345,
- TransportPacketMagicHeader: 123123,
+func setUpJunkCreator() JunkCreator {
+ jc := NewJunkCreator(Cfg{
+ IsSet: true,
+ JunkPacketCount: 5,
+ JunkPacketMinSize: 500,
+ JunkPacketMaxSize: 1000,
+ InitHeaderJunkSize: 30,
+ ResponseHeaderJunkSize: 40,
+ // TODO
+ // InitPacketMagicHeader: 123456,
+ // ResponsePacketMagicHeader: 67543,
+ // UnderloadPacketMagicHeader: 32345,
+ // TransportPacketMagicHeader: 123123,
})
- if err != nil {
- t.Errorf("failed to create junk creator %v", err)
- return junkCreator{}, err
- }
-
- return jc, nil
+ return jc
}
func Test_junkCreator_createJunkPackets(t *testing.T) {
- jc, err := setUpJunkCreator(t)
- if err != nil {
- return
- }
+ jc := setUpJunkCreator()
t.Run("valid", func(t *testing.T) {
- got := make([][]byte, 0, jc.aSecCfg.JunkPacketCount)
- err := jc.CreateJunkPackets(&got)
- if err != nil {
- t.Errorf(
- "junkCreator.createJunkPackets() = %v; failed",
- err,
- )
- return
- }
+ got := make([][]byte, 0, jc.cfg.JunkPacketCount)
+ jc.CreateJunkPackets(&got)
seen := make(map[string]bool)
for _, junk := range got {
key := string(junk)
@@ -61,34 +47,28 @@ func Test_junkCreator_createJunkPackets(t *testing.T) {
func Test_junkCreator_randomJunkWithSize(t *testing.T) {
t.Run("valid", func(t *testing.T) {
- jc, err := setUpJunkCreator(t)
- if err != nil {
- return
- }
- r1, _ := jc.randomJunkWithSize(10)
- r2, _ := jc.randomJunkWithSize(10)
+ jc := setUpJunkCreator()
+ r1 := jc.randomJunkWithSize(10)
+ r2 := jc.randomJunkWithSize(10)
fmt.Printf("%v\n%v\n", r1, r2)
if bytes.Equal(r1, r2) {
- t.Errorf("same junks %v", err)
+ t.Errorf("same junks")
return
}
})
}
func Test_junkCreator_randomPacketSize(t *testing.T) {
- jc, err := setUpJunkCreator(t)
- if err != nil {
- return
- }
+ jc := setUpJunkCreator()
for range [30]struct{}{} {
t.Run("valid", func(t *testing.T) {
- if got := jc.randomPacketSize(); jc.aSecCfg.JunkPacketMinSize > got ||
- got > jc.aSecCfg.JunkPacketMaxSize {
+ if got := jc.randomPacketSize(); jc.cfg.JunkPacketMinSize > got ||
+ got > jc.cfg.JunkPacketMaxSize {
t.Errorf(
"junkCreator.randomPacketSize() = %v, not between range [%v,%v]",
got,
- jc.aSecCfg.JunkPacketMinSize,
- jc.aSecCfg.JunkPacketMaxSize,
+ jc.cfg.JunkPacketMinSize,
+ jc.cfg.JunkPacketMaxSize,
)
}
})
@@ -96,10 +76,7 @@ func Test_junkCreator_randomPacketSize(t *testing.T) {
}
func Test_junkCreator_appendJunk(t *testing.T) {
- jc, err := setUpJunkCreator(t)
- if err != nil {
- return
- }
+ jc := setUpJunkCreator()
t.Run("valid", func(t *testing.T) {
s := "apple"
buffer := bytes.NewBuffer([]byte(s))
diff --git a/device/awg/magic_header.go b/device/awg/magic_header.go
new file mode 100644
index 0000000..b8bfb2f
--- /dev/null
+++ b/device/awg/magic_header.go
@@ -0,0 +1,93 @@
+package awg
+
+import (
+ "cmp"
+ "fmt"
+ "slices"
+ "strconv"
+ "strings"
+)
+
+type MagicHeader struct {
+ Min uint32
+ Max uint32
+}
+
+func NewMagicHeaderSameValue(value uint32) MagicHeader {
+ return MagicHeader{Min: value, Max: value}
+}
+
+func NewMagicHeader(min, max uint32) (MagicHeader, error) {
+ if min > max {
+ return MagicHeader{}, fmt.Errorf("min (%d) cannot be greater than max (%d)", min, max)
+ }
+
+ return MagicHeader{Min: min, Max: max}, nil
+}
+
+func ParseMagicHeader(key, value string) (MagicHeader, error) {
+ splitLimits := strings.Split(value, "-")
+ if len(splitLimits) != 2 {
+ // if there is no hyphen, we treat it as single magic header value
+ magicHeader, err := strconv.ParseUint(value, 10, 32)
+ if err != nil {
+ return MagicHeader{}, fmt.Errorf("parse key: %s; value: %s; %w", key, value, err)
+ }
+
+ return NewMagicHeader(uint32(magicHeader), uint32(magicHeader))
+ } else if len(splitLimits[0]) == 0 || len(splitLimits[1]) == 0 {
+ return MagicHeader{}, fmt.Errorf("invalid value for key: %s; value: %s; expected format: min-max", key, value)
+ }
+
+ min, err := strconv.ParseUint(splitLimits[0], 10, 32)
+ if err != nil {
+ return MagicHeader{}, fmt.Errorf("parse min key: %s; value: %s; %w", key, splitLimits[0], err)
+ }
+
+ max, err := strconv.ParseUint(splitLimits[1], 10, 32)
+ if err != nil {
+ return MagicHeader{}, fmt.Errorf("parse max key: %s; value: %s; %w", key, splitLimits[1], err)
+ }
+
+ magicHeader, err := NewMagicHeader(uint32(min), uint32(max))
+ if err != nil {
+ return MagicHeader{}, fmt.Errorf("new magicHeader key: %s; value: %s-%s; %w", key, splitLimits[0], splitLimits[1], err)
+ }
+
+ return magicHeader, nil
+}
+
+type MagicHeaders struct {
+ Values []MagicHeader
+ randomGenerator RandomNumberGenerator[uint32]
+}
+
+func NewMagicHeaders(headerValues []MagicHeader) (MagicHeaders, error) {
+ if len(headerValues) != 4 {
+ return MagicHeaders{}, fmt.Errorf("all header types should be included: %v", headerValues)
+ }
+
+ sortedMagicHeaders := slices.SortedFunc(slices.Values(headerValues), func(lhs MagicHeader, rhs MagicHeader) int {
+ return cmp.Compare(lhs.Min, rhs.Min)
+ })
+
+ for i := range 3 {
+ if sortedMagicHeaders[i].Max >= sortedMagicHeaders[i+1].Min {
+ return MagicHeaders{}, fmt.Errorf(
+ "magic headers shouldn't overlap; %v > %v",
+ sortedMagicHeaders[i].Max,
+ sortedMagicHeaders[i+1].Min,
+ )
+ }
+ }
+
+ return MagicHeaders{Values: headerValues, randomGenerator: NewPRNG[uint32]()}, nil
+}
+
+func (mh *MagicHeaders) Get(defaultMsgType uint32) (uint32, error) {
+ if defaultMsgType == 0 || defaultMsgType > 4 {
+ return 0, fmt.Errorf("invalid msg type: %d", defaultMsgType)
+ }
+
+ return mh.randomGenerator.RandomSizeInRange(mh.Values[defaultMsgType-1].Min, mh.Values[defaultMsgType-1].Max), nil
+}
diff --git a/device/awg/magic_header_test.go b/device/awg/magic_header_test.go
new file mode 100644
index 0000000..72a823e
--- /dev/null
+++ b/device/awg/magic_header_test.go
@@ -0,0 +1,488 @@
+package awg
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestNewMagicHeaderSameValue(t *testing.T) {
+ tests := []struct {
+ name string
+ value uint32
+ expected MagicHeader
+ }{
+ {
+ name: "zero value",
+ value: 0,
+ expected: MagicHeader{Min: 0, Max: 0},
+ },
+ {
+ name: "small value",
+ value: 1,
+ expected: MagicHeader{Min: 1, Max: 1},
+ },
+ {
+ name: "large value",
+ value: 4294967295, // max uint32
+ expected: MagicHeader{Min: 4294967295, Max: 4294967295},
+ },
+ {
+ name: "medium value",
+ value: 1000,
+ expected: MagicHeader{Min: 1000, Max: 1000},
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ result := NewMagicHeaderSameValue(tt.value)
+ require.Equal(t, tt.expected, result)
+ })
+ }
+}
+
+func TestNewMagicHeader(t *testing.T) {
+ tests := []struct {
+ name string
+ min uint32
+ max uint32
+ expected MagicHeader
+ errorMsg string
+ }{
+ {
+ name: "valid range",
+ min: 1,
+ max: 10,
+ expected: MagicHeader{Min: 1, Max: 10},
+ },
+ {
+ name: "equal values",
+ min: 5,
+ max: 5,
+ expected: MagicHeader{Min: 5, Max: 5},
+ },
+ {
+ name: "zero range",
+ min: 0,
+ max: 0,
+ expected: MagicHeader{Min: 0, Max: 0},
+ },
+ {
+ name: "max uint32 range",
+ min: 4294967294,
+ max: 4294967295,
+ expected: MagicHeader{Min: 4294967294, Max: 4294967295},
+ },
+ {
+ name: "min greater than max",
+ min: 10,
+ max: 5,
+ expected: MagicHeader{},
+ errorMsg: "min (10) cannot be greater than max (5)",
+ },
+ {
+ name: "large min greater than max",
+ min: 4294967295,
+ max: 1,
+ expected: MagicHeader{},
+ errorMsg: "min (4294967295) cannot be greater than max (1)",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ result, err := NewMagicHeader(tt.min, tt.max)
+
+ if tt.errorMsg != "" {
+ require.Error(t, err)
+ require.Contains(t, err.Error(), tt.errorMsg)
+ require.Equal(t, MagicHeader{}, result)
+ } else {
+ require.NoError(t, err)
+ require.Equal(t, tt.expected, result)
+ }
+ })
+ }
+}
+
+func TestParseMagicHeader(t *testing.T) {
+ tests := []struct {
+ name string
+ key string
+ value string
+ expected MagicHeader
+ errorMsg string
+ }{
+ {
+ name: "single value",
+ key: "header1",
+ value: "100",
+ expected: MagicHeader{Min: 100, Max: 100},
+ },
+ {
+ name: "valid range",
+ key: "header2",
+ value: "10-20",
+ expected: MagicHeader{Min: 10, Max: 20},
+ },
+ {
+ name: "zero single value",
+ key: "header3",
+ value: "0",
+ expected: MagicHeader{Min: 0, Max: 0},
+ },
+ {
+ name: "zero range",
+ key: "header4",
+ value: "0-0",
+ expected: MagicHeader{Min: 0, Max: 0},
+ },
+ {
+ name: "max uint32 single",
+ key: "header5",
+ value: "4294967295",
+ expected: MagicHeader{Min: 4294967295, Max: 4294967295},
+ },
+ {
+ name: "max uint32 range",
+ key: "header6",
+ value: "4294967294-4294967295",
+ expected: MagicHeader{Min: 4294967294, Max: 4294967295},
+ },
+ {
+ name: "invalid single value - not number",
+ key: "header7",
+ value: "abc",
+ expected: MagicHeader{},
+ errorMsg: "parse key: header7; value: abc;",
+ },
+ {
+ name: "invalid single value - negative",
+ key: "header8",
+ value: "-5",
+ expected: MagicHeader{},
+ errorMsg: "invalid value for key: header8; value: -5;",
+ },
+ {
+ name: "invalid single value - too large",
+ key: "header9",
+ value: "4294967296",
+ expected: MagicHeader{},
+ errorMsg: "parse key: header9; value: 4294967296;",
+ },
+ {
+ name: "invalid range - min not number",
+ key: "header10",
+ value: "abc-10",
+ expected: MagicHeader{},
+ errorMsg: "parse min key: header10; value: abc;",
+ },
+ {
+ name: "invalid range - max not number",
+ key: "header11",
+ value: "10-abc",
+ expected: MagicHeader{},
+ errorMsg: "parse max key: header11; value: abc;",
+ },
+ {
+ name: "invalid range - min greater than max",
+ key: "header12",
+ value: "20-10",
+ expected: MagicHeader{},
+ errorMsg: "new magicHeader key: header12; value: 20-10;",
+ },
+ {
+ name: "invalid range - too many parts",
+ key: "header13",
+ value: "10-20-30",
+ expected: MagicHeader{},
+ errorMsg: "parse key: header13; value: 10-20-30;",
+ },
+ {
+ name: "empty value",
+ key: "header14",
+ value: "",
+ expected: MagicHeader{},
+ errorMsg: "parse key: header14; value: ;",
+ },
+ {
+ name: "hyphen only",
+ key: "header15",
+ value: "-",
+ expected: MagicHeader{},
+ errorMsg: "invalid value for key: header15; value: -;",
+ },
+ {
+ name: "empty min",
+ key: "header16",
+ value: "-10",
+ expected: MagicHeader{},
+ errorMsg: "invalid value for key: header16; value: -10;",
+ },
+ {
+ name: "empty max",
+ key: "header17",
+ value: "10-",
+ expected: MagicHeader{},
+ errorMsg: "invalid value for key: header17; value: 10-;",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ result, err := ParseMagicHeader(tt.key, tt.value)
+
+ if tt.errorMsg != "" {
+ require.Error(t, err)
+ require.Contains(t, err.Error(), tt.errorMsg)
+ require.Equal(t, MagicHeader{}, result)
+ } else {
+ require.NoError(t, err)
+ require.Equal(t, tt.expected, result)
+ }
+ })
+ }
+}
+
+func TestNewMagicHeaders(t *testing.T) {
+ tests := []struct {
+ name string
+ magicHeaders []MagicHeader
+ errorMsg string
+ }{
+ {
+ name: "valid non-overlapping headers",
+ magicHeaders: []MagicHeader{
+ {Min: 1, Max: 10},
+ {Min: 11, Max: 20},
+ {Min: 21, Max: 30},
+ {Min: 31, Max: 40},
+ },
+ },
+ {
+ name: "valid adjacent headers",
+ magicHeaders: []MagicHeader{
+ {Min: 1, Max: 1},
+ {Min: 2, Max: 2},
+ {Min: 3, Max: 3},
+ {Min: 4, Max: 4},
+ },
+ },
+ {
+ name: "valid zero-based headers",
+ magicHeaders: []MagicHeader{
+ {Min: 0, Max: 0},
+ {Min: 1, Max: 1},
+ {Min: 2, Max: 2},
+ {Min: 3, Max: 3},
+ },
+ },
+ {
+ name: "valid large value headers",
+ magicHeaders: []MagicHeader{
+ {Min: 4294967290, Max: 4294967291},
+ {Min: 4294967292, Max: 4294967293},
+ {Min: 4294967294, Max: 4294967294},
+ {Min: 4294967295, Max: 4294967295},
+ },
+ },
+ {
+ name: "too few headers",
+ magicHeaders: []MagicHeader{
+ {Min: 1, Max: 10},
+ {Min: 11, Max: 20},
+ {Min: 21, Max: 30},
+ },
+ errorMsg: "all header types should be included:",
+ },
+ {
+ name: "too many headers",
+ magicHeaders: []MagicHeader{
+ {Min: 1, Max: 10},
+ {Min: 11, Max: 20},
+ {Min: 21, Max: 30},
+ {Min: 31, Max: 40},
+ {Min: 41, Max: 50},
+ },
+ errorMsg: "all header types should be included:",
+ },
+ {
+ name: "empty headers",
+ magicHeaders: []MagicHeader{},
+ errorMsg: "all header types should be included:",
+ },
+ {
+ name: "overlapping headers",
+ magicHeaders: []MagicHeader{
+ {Min: 1, Max: 15},
+ {Min: 10, Max: 20},
+ {Min: 25, Max: 30},
+ {Min: 35, Max: 40},
+ },
+ errorMsg: "magic headers shouldn't overlap;",
+ },
+ {
+ name: "overlapping headers at limit-first",
+ magicHeaders: []MagicHeader{
+ {Min: 1, Max: 10},
+ {Min: 10, Max: 20},
+ {Min: 25, Max: 30},
+ {Min: 35, Max: 40},
+ },
+ errorMsg: "magic headers shouldn't overlap;",
+ },
+ {
+ name: "overlapping headers at limit-second",
+ magicHeaders: []MagicHeader{
+ {Min: 1, Max: 10},
+ {Min: 15, Max: 25},
+ {Min: 25, Max: 30},
+ {Min: 35, Max: 40},
+ },
+ errorMsg: "magic headers shouldn't overlap;",
+ },
+ {
+ name: "overlapping headers at limit-third",
+ magicHeaders: []MagicHeader{
+ {Min: 1, Max: 10},
+ {Min: 15, Max: 25},
+ {Min: 30, Max: 35},
+ {Min: 35, Max: 40},
+ },
+ errorMsg: "magic headers shouldn't overlap;",
+ },
+ {
+ name: "identical ranges",
+ magicHeaders: []MagicHeader{
+ {Min: 10, Max: 20},
+ {Min: 10, Max: 20},
+ {Min: 25, Max: 30},
+ {Min: 35, Max: 40},
+ },
+ errorMsg: "magic headers shouldn't overlap;",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ result, err := NewMagicHeaders(tt.magicHeaders)
+
+ if tt.errorMsg != "" {
+ require.Error(t, err)
+ require.Contains(t, err.Error(), tt.errorMsg)
+ require.Equal(t, MagicHeaders{}, result)
+ } else {
+ require.NoError(t, err)
+ require.Equal(t, tt.magicHeaders, result.Values)
+ require.NotNil(t, result.randomGenerator)
+ }
+ })
+ }
+}
+
+// Mock PRNG for testing
+type mockPRNG struct {
+ returnValue uint32
+}
+
+func (m *mockPRNG) RandomSizeInRange(min, max uint32) uint32 {
+ return m.returnValue
+}
+
+func (m *mockPRNG) Get() uint64 {
+ return 0
+}
+func (m *mockPRNG) ReadSize(size int) []byte {
+ return make([]byte, size)
+}
+
+func TestMagicHeaders_Get(t *testing.T) {
+ // Create test headers
+ headers := []MagicHeader{
+ {Min: 1, Max: 10},
+ {Min: 11, Max: 20},
+ {Min: 21, Max: 30},
+ {Min: 31, Max: 40},
+ }
+
+ tests := []struct {
+ name string
+ defaultMsgType uint32
+ mockValue uint32
+ expectedValue uint32
+ errorMsg string
+ }{
+ {
+ name: "valid type 1",
+ defaultMsgType: 1,
+ mockValue: 5,
+ expectedValue: 5,
+ },
+ {
+ name: "valid type 2",
+ defaultMsgType: 2,
+ mockValue: 15,
+ expectedValue: 15,
+ },
+ {
+ name: "valid type 3",
+ defaultMsgType: 3,
+ mockValue: 25,
+ expectedValue: 25,
+ },
+ {
+ name: "valid type 4",
+ defaultMsgType: 4,
+ mockValue: 35,
+ expectedValue: 35,
+ },
+ {
+ name: "invalid type 0",
+ defaultMsgType: 0,
+ mockValue: 0,
+ expectedValue: 0,
+ errorMsg: "invalid msg type: 0",
+ },
+ {
+ name: "invalid type 5",
+ defaultMsgType: 5,
+ mockValue: 0,
+ expectedValue: 0,
+ errorMsg: "invalid msg type: 5",
+ },
+ {
+ name: "invalid type max uint32",
+ defaultMsgType: 4294967295,
+ mockValue: 0,
+ expectedValue: 0,
+ errorMsg: "invalid msg type: 4294967295",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ // Create a new instance with mock PRNG for each test
+ testMagicHeaders := MagicHeaders{
+ Values: headers,
+ randomGenerator: &mockPRNG{returnValue: tt.mockValue},
+ }
+
+ result, err := testMagicHeaders.Get(tt.defaultMsgType)
+
+ if tt.errorMsg != "" {
+ require.Error(t, err)
+ require.Contains(t, err.Error(), tt.errorMsg)
+ require.Equal(t, uint32(0), result)
+ } else {
+ require.NoError(t, err)
+ require.Equal(t, tt.expectedValue, result)
+ }
+ })
+ }
+}
diff --git a/device/awg/prng.go b/device/awg/prng.go
new file mode 100644
index 0000000..e7661d7
--- /dev/null
+++ b/device/awg/prng.go
@@ -0,0 +1,50 @@
+package awg
+
+import (
+ crand "crypto/rand"
+ v2 "math/rand/v2"
+
+ "golang.org/x/exp/constraints"
+)
+
+type RandomNumberGenerator[T constraints.Integer] interface {
+ RandomSizeInRange(min, max T) T
+ Get() uint64
+ ReadSize(size int) []byte
+}
+
+type PRNG[T constraints.Integer] struct {
+ cha8Rand *v2.ChaCha8
+}
+
+func NewPRNG[T constraints.Integer]() PRNG[T] {
+ buf := make([]byte, 32)
+ _, _ = crand.Read(buf)
+
+ return PRNG[T]{
+ cha8Rand: v2.NewChaCha8([32]byte(buf)),
+ }
+}
+
+func (p PRNG[T]) RandomSizeInRange(min, max T) T {
+ if min > max {
+ panic("min must be less than max")
+ }
+
+ if min == max {
+ return min
+ }
+
+ return T(p.Get()%uint64(max-min)) + min
+}
+
+func (p PRNG[T]) Get() uint64 {
+ return p.cha8Rand.Uint64()
+}
+
+func (p PRNG[T]) ReadSize(size int) []byte {
+ // TODO: use a memory pool to allocate
+ buf := make([]byte, size)
+ _, _ = p.cha8Rand.Read(buf)
+ return buf
+}
diff --git a/device/awg/tag_parser.go b/device/awg/tag_parser.go
index 2b09226..c2e1090 100644
--- a/device/awg/tag_parser.go
+++ b/device/awg/tag_parser.go
@@ -55,7 +55,7 @@ func parseTag(input string) (Tag, error) {
return tag, nil
}
-func Parse(name, input string) (TagJunkPacketGenerator, error) {
+func ParseTagJunkGenerator(name, input string) (TagJunkPacketGenerator, error) {
inputSlice := strings.Split(input, "<")
if len(inputSlice) <= 1 {
return TagJunkPacketGenerator{}, fmt.Errorf("empty input: %s", input)
diff --git a/device/awg/tag_parser_test.go b/device/awg/tag_parser_test.go
index 8f828ec..3229cee 100644
--- a/device/awg/tag_parser_test.go
+++ b/device/awg/tag_parser_test.go
@@ -64,7 +64,7 @@ func TestParse(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- _, err := Parse(tt.args.name, tt.args.input)
+ _, err := ParseTagJunkGenerator(tt.args.name, tt.args.input)
// TODO: ErrorAs doesn't work as you think
if tt.wantErr != nil {
diff --git a/device/cookie.go b/device/cookie.go
index a093c8b..5e09806 100644
--- a/device/cookie.go
+++ b/device/cookie.go
@@ -118,6 +118,7 @@ func (st *CookieChecker) CreateReply(
msg []byte,
recv uint32,
src []byte,
+ msgType uint32,
) (*MessageCookieReply, error) {
st.RLock()
@@ -153,7 +154,7 @@ func (st *CookieChecker) CreateReply(
smac1 := smac2 - blake2s.Size128
reply := new(MessageCookieReply)
- reply.Type = MessageCookieReplyType
+ reply.Type = msgType
reply.Receiver = recv
_, err := rand.Read(reply.Nonce[:])
diff --git a/device/cookie_test.go b/device/cookie_test.go
index c937290..e5a2bd4 100644
--- a/device/cookie_test.go
+++ b/device/cookie_test.go
@@ -99,7 +99,7 @@ func TestCookieMAC1(t *testing.T) {
0x8c, 0xe1, 0xe8, 0xfa, 0x67, 0x20, 0x80, 0x6d,
}
generator.AddMacs(msg)
- reply, err := checker.CreateReply(msg, 1377, src)
+ reply, err := checker.CreateReply(msg, 1377, src, DefaultMessageCookieReplyType)
if err != nil {
t.Fatal("Failed to create cookie reply:", err)
}
diff --git a/device/device.go b/device/device.go
index 1829352..fdb8620 100644
--- a/device/device.go
+++ b/device/device.go
@@ -6,7 +6,9 @@
package device
import (
+ "encoding/binary"
"errors"
+ "fmt"
"runtime"
"sync"
"sync/atomic"
@@ -578,6 +580,7 @@ func (device *Device) BindClose() error {
device.net.Unlock()
return err
}
+
func (device *Device) isAWG() bool {
return device.version >= VersionAwg
}
@@ -591,171 +594,123 @@ func (device *Device) resetProtocol() {
}
func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
- if !tempAwg.ASecCfg.IsSet && !tempAwg.HandshakeHandler.IsSet {
+ if !tempAwg.Cfg.IsSet && !tempAwg.HandshakeHandler.IsSet {
return nil
}
var errs []error
- isASecOn := false
- device.awg.ASecMux.Lock()
- if tempAwg.ASecCfg.JunkPacketCount < 0 {
+ isAwgOn := false
+ device.awg.Mux.Lock()
+ if tempAwg.Cfg.JunkPacketCount < 0 {
errs = append(errs, ipcErrorf(
ipc.IpcErrorInvalid,
"JunkPacketCount should be non negative",
),
)
}
- device.awg.ASecCfg.JunkPacketCount = tempAwg.ASecCfg.JunkPacketCount
- if tempAwg.ASecCfg.JunkPacketCount != 0 {
- isASecOn = true
+ device.awg.Cfg.JunkPacketCount = tempAwg.Cfg.JunkPacketCount
+ if tempAwg.Cfg.JunkPacketCount != 0 {
+ isAwgOn = true
}
- device.awg.ASecCfg.JunkPacketMinSize = tempAwg.ASecCfg.JunkPacketMinSize
- if tempAwg.ASecCfg.JunkPacketMinSize != 0 {
- isASecOn = true
+ device.awg.Cfg.JunkPacketMinSize = tempAwg.Cfg.JunkPacketMinSize
+ if tempAwg.Cfg.JunkPacketMinSize != 0 {
+ isAwgOn = true
}
- if device.awg.ASecCfg.JunkPacketCount > 0 &&
- tempAwg.ASecCfg.JunkPacketMaxSize == tempAwg.ASecCfg.JunkPacketMinSize {
+ if device.awg.Cfg.JunkPacketCount > 0 &&
+ tempAwg.Cfg.JunkPacketMaxSize == tempAwg.Cfg.JunkPacketMinSize {
- tempAwg.ASecCfg.JunkPacketMaxSize++ // to make rand gen work
+ tempAwg.Cfg.JunkPacketMaxSize++ // to make rand gen work
}
- if tempAwg.ASecCfg.JunkPacketMaxSize >= MaxSegmentSize {
- device.awg.ASecCfg.JunkPacketMinSize = 0
- device.awg.ASecCfg.JunkPacketMaxSize = 1
+ if tempAwg.Cfg.JunkPacketMaxSize >= MaxSegmentSize {
+ device.awg.Cfg.JunkPacketMinSize = 0
+ device.awg.Cfg.JunkPacketMaxSize = 1
errs = append(errs, ipcErrorf(
ipc.IpcErrorInvalid,
"JunkPacketMaxSize: %d; should be smaller than maxSegmentSize: %d",
- tempAwg.ASecCfg.JunkPacketMaxSize,
+ tempAwg.Cfg.JunkPacketMaxSize,
MaxSegmentSize,
))
- } else if tempAwg.ASecCfg.JunkPacketMaxSize < tempAwg.ASecCfg.JunkPacketMinSize {
+ } else if tempAwg.Cfg.JunkPacketMaxSize < tempAwg.Cfg.JunkPacketMinSize {
errs = append(errs, ipcErrorf(
ipc.IpcErrorInvalid,
"maxSize: %d; should be greater than minSize: %d",
- tempAwg.ASecCfg.JunkPacketMaxSize,
- tempAwg.ASecCfg.JunkPacketMinSize,
+ tempAwg.Cfg.JunkPacketMaxSize,
+ tempAwg.Cfg.JunkPacketMinSize,
))
} else {
- device.awg.ASecCfg.JunkPacketMaxSize = tempAwg.ASecCfg.JunkPacketMaxSize
+ device.awg.Cfg.JunkPacketMaxSize = tempAwg.Cfg.JunkPacketMaxSize
}
- if tempAwg.ASecCfg.JunkPacketMaxSize != 0 {
- isASecOn = true
+ if tempAwg.Cfg.JunkPacketMaxSize != 0 {
+ isAwgOn = true
}
- newInitSize := MessageInitiationSize + tempAwg.ASecCfg.InitHeaderJunkSize
+ magicHeaders := make([]awg.MagicHeader, 4)
- if newInitSize >= MaxSegmentSize {
- errs = append(errs, ipcErrorf(
+ if len(tempAwg.Cfg.MagicHeaders.Values) != 4 {
+ return ipcErrorf(
ipc.IpcErrorInvalid,
- `init header size(148) + junkSize:%d; should be smaller than maxSegmentSize: %d`,
- tempAwg.ASecCfg.InitHeaderJunkSize,
- MaxSegmentSize,
- ),
+ "magic headers should have 4 values; got: %d",
+ len(tempAwg.Cfg.MagicHeaders.Values),
)
- } else {
- device.awg.ASecCfg.InitHeaderJunkSize = tempAwg.ASecCfg.InitHeaderJunkSize
}
- if tempAwg.ASecCfg.InitHeaderJunkSize != 0 {
- isASecOn = true
- }
-
- newResponseSize := MessageResponseSize + tempAwg.ASecCfg.ResponseHeaderJunkSize
-
- if newResponseSize >= MaxSegmentSize {
- errs = append(errs, ipcErrorf(
- ipc.IpcErrorInvalid,
- `response header size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d`,
- tempAwg.ASecCfg.ResponseHeaderJunkSize,
- MaxSegmentSize,
- ),
- )
- } else {
- device.awg.ASecCfg.ResponseHeaderJunkSize = tempAwg.ASecCfg.ResponseHeaderJunkSize
- }
-
- if tempAwg.ASecCfg.ResponseHeaderJunkSize != 0 {
- isASecOn = true
- }
-
- newCookieSize := MessageCookieReplySize + tempAwg.ASecCfg.CookieReplyHeaderJunkSize
-
- if newCookieSize >= MaxSegmentSize {
- errs = append(errs, ipcErrorf(
- ipc.IpcErrorInvalid,
- `cookie reply size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d`,
- tempAwg.ASecCfg.CookieReplyHeaderJunkSize,
- MaxSegmentSize,
- ),
- )
- } else {
- device.awg.ASecCfg.CookieReplyHeaderJunkSize = tempAwg.ASecCfg.CookieReplyHeaderJunkSize
- }
-
- if tempAwg.ASecCfg.CookieReplyHeaderJunkSize != 0 {
- isASecOn = true
- }
-
- newTransportSize := MessageTransportSize + tempAwg.ASecCfg.TransportHeaderJunkSize
-
- if newTransportSize >= MaxSegmentSize {
- errs = append(errs, ipcErrorf(
- ipc.IpcErrorInvalid,
- `transport size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d`,
- tempAwg.ASecCfg.TransportHeaderJunkSize,
- MaxSegmentSize,
- ),
- )
- } else {
- device.awg.ASecCfg.TransportHeaderJunkSize = tempAwg.ASecCfg.TransportHeaderJunkSize
- }
-
- if tempAwg.ASecCfg.TransportHeaderJunkSize != 0 {
- isASecOn = true
- }
-
- if tempAwg.ASecCfg.InitPacketMagicHeader > 4 {
- isASecOn = true
+ if tempAwg.Cfg.MagicHeaders.Values[0].Min > 4 {
+ isAwgOn = true
device.log.Verbosef("UAPI: Updating init_packet_magic_header")
- device.awg.ASecCfg.InitPacketMagicHeader = tempAwg.ASecCfg.InitPacketMagicHeader
- MessageInitiationType = device.awg.ASecCfg.InitPacketMagicHeader
+ magicHeaders[0] = tempAwg.Cfg.MagicHeaders.Values[0]
+
+ MessageInitiationType = magicHeaders[0].Min
} else {
device.log.Verbosef("UAPI: Using default init type")
MessageInitiationType = DefaultMessageInitiationType
+ magicHeaders[0] = awg.NewMagicHeaderSameValue(DefaultMessageInitiationType)
}
- if tempAwg.ASecCfg.ResponsePacketMagicHeader > 4 {
- isASecOn = true
+ if tempAwg.Cfg.MagicHeaders.Values[1].Min > 4 {
+ isAwgOn = true
+
device.log.Verbosef("UAPI: Updating response_packet_magic_header")
- device.awg.ASecCfg.ResponsePacketMagicHeader = tempAwg.ASecCfg.ResponsePacketMagicHeader
- MessageResponseType = device.awg.ASecCfg.ResponsePacketMagicHeader
+ magicHeaders[1] = tempAwg.Cfg.MagicHeaders.Values[1]
+ MessageResponseType = magicHeaders[1].Min
} else {
device.log.Verbosef("UAPI: Using default response type")
MessageResponseType = DefaultMessageResponseType
+ magicHeaders[1] = awg.NewMagicHeaderSameValue(DefaultMessageResponseType)
}
- if tempAwg.ASecCfg.UnderloadPacketMagicHeader > 4 {
- isASecOn = true
+ if tempAwg.Cfg.MagicHeaders.Values[2].Min > 4 {
+ isAwgOn = true
+
device.log.Verbosef("UAPI: Updating underload_packet_magic_header")
- device.awg.ASecCfg.UnderloadPacketMagicHeader = tempAwg.ASecCfg.UnderloadPacketMagicHeader
- MessageCookieReplyType = device.awg.ASecCfg.UnderloadPacketMagicHeader
+ magicHeaders[2] = tempAwg.Cfg.MagicHeaders.Values[2]
+ MessageCookieReplyType = magicHeaders[2].Min
} else {
device.log.Verbosef("UAPI: Using default underload type")
MessageCookieReplyType = DefaultMessageCookieReplyType
+ magicHeaders[2] = awg.NewMagicHeaderSameValue(DefaultMessageCookieReplyType)
}
- if tempAwg.ASecCfg.TransportPacketMagicHeader > 4 {
- isASecOn = true
+ if tempAwg.Cfg.MagicHeaders.Values[3].Min > 4 {
+ isAwgOn = true
+
device.log.Verbosef("UAPI: Updating transport_packet_magic_header")
- device.awg.ASecCfg.TransportPacketMagicHeader = tempAwg.ASecCfg.TransportPacketMagicHeader
- MessageTransportType = device.awg.ASecCfg.TransportPacketMagicHeader
+ magicHeaders[3] = tempAwg.Cfg.MagicHeaders.Values[3]
+ MessageTransportType = magicHeaders[3].Min
} else {
device.log.Verbosef("UAPI: Using default transport type")
MessageTransportType = DefaultMessageTransportType
+ magicHeaders[3] = awg.NewMagicHeaderSameValue(DefaultMessageTransportType)
+ }
+
+ var err error
+ device.awg.Cfg.MagicHeaders, err = awg.NewMagicHeaders(magicHeaders)
+ if err != nil {
+ errs = append(errs, ipcErrorf(ipc.IpcErrorInvalid, "new magic headers: %w", err))
}
isSameHeaderMap := map[uint32]struct{}{
@@ -778,6 +733,78 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
)
}
+ newInitSize := MessageInitiationSize + tempAwg.Cfg.InitHeaderJunkSize
+
+ if newInitSize >= MaxSegmentSize {
+ errs = append(errs, ipcErrorf(
+ ipc.IpcErrorInvalid,
+ `init header size(148) + junkSize:%d; should be smaller than maxSegmentSize: %d`,
+ tempAwg.Cfg.InitHeaderJunkSize,
+ MaxSegmentSize,
+ ),
+ )
+ } else {
+ device.awg.Cfg.InitHeaderJunkSize = tempAwg.Cfg.InitHeaderJunkSize
+ }
+
+ if tempAwg.Cfg.InitHeaderJunkSize != 0 {
+ isAwgOn = true
+ }
+
+ newResponseSize := MessageResponseSize + tempAwg.Cfg.ResponseHeaderJunkSize
+
+ if newResponseSize >= MaxSegmentSize {
+ errs = append(errs, ipcErrorf(
+ ipc.IpcErrorInvalid,
+ `response header size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d`,
+ tempAwg.Cfg.ResponseHeaderJunkSize,
+ MaxSegmentSize,
+ ),
+ )
+ } else {
+ device.awg.Cfg.ResponseHeaderJunkSize = tempAwg.Cfg.ResponseHeaderJunkSize
+ }
+
+ if tempAwg.Cfg.ResponseHeaderJunkSize != 0 {
+ isAwgOn = true
+ }
+
+ newCookieSize := MessageCookieReplySize + tempAwg.Cfg.CookieReplyHeaderJunkSize
+
+ if newCookieSize >= MaxSegmentSize {
+ errs = append(errs, ipcErrorf(
+ ipc.IpcErrorInvalid,
+ `cookie reply size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d`,
+ tempAwg.Cfg.CookieReplyHeaderJunkSize,
+ MaxSegmentSize,
+ ),
+ )
+ } else {
+ device.awg.Cfg.CookieReplyHeaderJunkSize = tempAwg.Cfg.CookieReplyHeaderJunkSize
+ }
+
+ if tempAwg.Cfg.CookieReplyHeaderJunkSize != 0 {
+ isAwgOn = true
+ }
+
+ newTransportSize := MessageTransportSize + tempAwg.Cfg.TransportHeaderJunkSize
+
+ if newTransportSize >= MaxSegmentSize {
+ errs = append(errs, ipcErrorf(
+ ipc.IpcErrorInvalid,
+ `transport size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d`,
+ tempAwg.Cfg.TransportHeaderJunkSize,
+ MaxSegmentSize,
+ ),
+ )
+ } else {
+ device.awg.Cfg.TransportHeaderJunkSize = tempAwg.Cfg.TransportHeaderJunkSize
+ }
+
+ if tempAwg.Cfg.TransportHeaderJunkSize != 0 {
+ isAwgOn = true
+ }
+
isSameSizeMap := map[int]struct{}{
newInitSize: {},
newResponseSize: {},
@@ -797,10 +824,10 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
)
} else {
msgTypeToJunkSize = map[uint32]int{
- MessageInitiationType: device.awg.ASecCfg.InitHeaderJunkSize,
- MessageResponseType: device.awg.ASecCfg.ResponseHeaderJunkSize,
- MessageCookieReplyType: device.awg.ASecCfg.CookieReplyHeaderJunkSize,
- MessageTransportType: device.awg.ASecCfg.TransportHeaderJunkSize,
+ MessageInitiationType: device.awg.Cfg.InitHeaderJunkSize,
+ MessageResponseType: device.awg.Cfg.ResponseHeaderJunkSize,
+ MessageCookieReplyType: device.awg.Cfg.CookieReplyHeaderJunkSize,
+ MessageTransportType: device.awg.Cfg.TransportHeaderJunkSize,
}
packetSizeToMsgType = map[int]uint32{
@@ -811,12 +838,8 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
}
}
- device.awg.IsASecOn.SetTo(isASecOn)
- var err error
- device.awg.JunkCreator, err = awg.NewJunkCreator(device.awg.ASecCfg)
- if err != nil {
- errs = append(errs, err)
- }
+ device.awg.IsOn.SetTo(isAwgOn)
+ device.awg.JunkCreator = awg.NewJunkCreator(device.awg.Cfg)
if tempAwg.HandshakeHandler.IsSet {
if err := tempAwg.HandshakeHandler.Validate(); err != nil {
@@ -824,15 +847,92 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
ipc.IpcErrorInvalid, "handshake handler validate: %w", err))
} else {
device.awg.HandshakeHandler = tempAwg.HandshakeHandler
- device.awg.HandshakeHandler.ControlledJunk.DefaultJunkCount = tempAwg.ASecCfg.JunkPacketCount
- device.awg.HandshakeHandler.SpecialJunk.DefaultJunkCount = tempAwg.ASecCfg.JunkPacketCount
+ device.awg.HandshakeHandler.ControlledJunk.DefaultJunkCount = tempAwg.Cfg.JunkPacketCount
+ device.awg.HandshakeHandler.SpecialJunk.DefaultJunkCount = tempAwg.Cfg.JunkPacketCount
device.version = VersionAwgSpecialHandshake
}
} else {
device.version = VersionAwg
}
- device.awg.ASecMux.Unlock()
+ device.awg.Mux.Unlock()
return errors.Join(errs...)
}
+
+func (device *Device) ProcessAWGPacket(size int, packet *[]byte, buffer *[MaxMessageSize]byte) (uint32, error) {
+ // TODO:
+ // if awg.WaitResponse.ShouldWait.IsSet() {
+ // awg.WaitResponse.Channel <- struct{}{}
+ // }
+
+ expectedMsgType, isKnownSize := packetSizeToMsgType[size]
+ if !isKnownSize {
+ msgType, err := device.handleTransport(size, packet, buffer)
+
+ if err != nil {
+ return 0, fmt.Errorf("handle transport: %w", err)
+ }
+
+ return msgType, nil
+ }
+
+ junkSize := msgTypeToJunkSize[expectedMsgType]
+
+ // transport size can align with other header types;
+ // making sure we have the right actualMsgType
+ actualMsgType, err := device.getMsgType(packet, junkSize)
+ if err != nil {
+ return 0, fmt.Errorf("get msg type: %w", err)
+ }
+
+ if actualMsgType == expectedMsgType {
+ *packet = (*packet)[junkSize:]
+ return actualMsgType, nil
+ }
+
+ device.log.Verbosef("awg: transport packet lined up with another msg type")
+
+ msgType, err := device.handleTransport(size, packet, buffer)
+ if err != nil {
+ return 0, fmt.Errorf("handle transport: %w", err)
+ }
+
+ return msgType, nil
+}
+
+func (device *Device) getMsgType(packet *[]byte, junkSize int) (uint32, error) {
+ msgTypeValue := binary.LittleEndian.Uint32((*packet)[junkSize : junkSize+4])
+ msgType, err := device.awg.GetMagicHeaderMinFor(msgTypeValue)
+
+ if err != nil {
+ return 0, fmt.Errorf("get magic header min: %w", err)
+ }
+
+ return msgType, nil
+}
+
+func (device *Device) handleTransport(size int, packet *[]byte, buffer *[MaxMessageSize]byte) (uint32, error) {
+ junkSize := device.awg.Cfg.TransportHeaderJunkSize
+
+ msgType, err := device.getMsgType(packet, junkSize)
+ if err != nil {
+ return 0, fmt.Errorf("get msg type: %w", err)
+ }
+
+ if msgType != MessageTransportType {
+ // probably a junk packet
+ return 0, fmt.Errorf("Received message with unknown type: %d", msgType)
+ }
+
+ if junkSize > 0 {
+ // remove junk from buffer by shifting the packet
+ // this buffer is also used for decryption, so it needs to be corrected
+ copy((*buffer)[:size], (*packet)[junkSize:])
+ size -= junkSize
+ // need to reinitialize packet as well
+ (*packet) = (*packet)[:size]
+ }
+
+ return msgType, nil
+}
diff --git a/device/device_test.go b/device/device_test.go
index 5824cf9..5820128 100644
--- a/device/device_test.go
+++ b/device/device_test.go
@@ -232,14 +232,14 @@ func TestAWGDevicePing(t *testing.T) {
"jc", "5",
"jmin", "500",
"jmax", "1000",
- "s1", "30",
- "s2", "40",
- "s3", "50",
- "s4", "5",
- "h1", "123456",
- "h2", "67543",
- "h3", "123123",
- "h4", "32345",
+ "s1", "15",
+ "s2", "18",
+ "s3", "20",
+ "s4", "25",
+ "h1", "123456-123500",
+ "h2", "67543-67550",
+ "h3", "123123-123200",
+ "h4", "32345-32350",
)
t.Run("ping 1.0.0.1", func(t *testing.T) {
pair.Send(t, Ping, nil)
diff --git a/device/noise-protocol.go b/device/noise-protocol.go
index f637b24..7d15d87 100644
--- a/device/noise-protocol.go
+++ b/device/noise-protocol.go
@@ -205,12 +205,18 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
handshake.mixHash(handshake.remoteStatic[:])
- device.awg.ASecMux.RLock()
+ device.awg.Mux.RLock()
+ msgType, err := device.awg.GetMsgType(DefaultMessageInitiationType)
+ if err != nil {
+ device.awg.Mux.RUnlock()
+ return nil, fmt.Errorf("get message type: %w", err)
+ }
+
msg := MessageInitiation{
- Type: MessageInitiationType,
+ Type: msgType,
Ephemeral: handshake.localEphemeral.publicKey(),
}
- device.awg.ASecMux.RUnlock()
+ device.awg.Mux.RUnlock()
handshake.mixKey(msg.Ephemeral[:])
handshake.mixHash(msg.Ephemeral[:])
@@ -264,12 +270,13 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
chainKey [blake2s.Size]byte
)
- device.awg.ASecMux.RLock()
+ device.awg.Mux.RLock()
+
if msg.Type != MessageInitiationType {
- device.awg.ASecMux.RUnlock()
+ device.awg.Mux.RUnlock()
return nil
}
- device.awg.ASecMux.RUnlock()
+ device.awg.Mux.RUnlock()
device.staticIdentity.RLock()
defer device.staticIdentity.RUnlock()
@@ -384,9 +391,14 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
}
var msg MessageResponse
- device.awg.ASecMux.RLock()
- msg.Type = MessageResponseType
- device.awg.ASecMux.RUnlock()
+ device.awg.Mux.RLock()
+ msg.Type, err = device.awg.GetMsgType(DefaultMessageResponseType)
+ if err != nil {
+ device.awg.Mux.RUnlock()
+ return nil, fmt.Errorf("get message type: %w", err)
+ }
+
+ device.awg.Mux.RUnlock()
msg.Sender = handshake.localIndex
msg.Receiver = handshake.remoteIndex
@@ -436,12 +448,13 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
}
func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
- device.awg.ASecMux.RLock()
+ device.awg.Mux.RLock()
+
if msg.Type != MessageResponseType {
- device.awg.ASecMux.RUnlock()
+ device.awg.Mux.RUnlock()
return nil
}
- device.awg.ASecMux.RUnlock()
+ device.awg.Mux.RUnlock()
// lookup handshake by receiver
diff --git a/device/receive.go b/device/receive.go
index 6daba0d..4c34799 100644
--- a/device/receive.go
+++ b/device/receive.go
@@ -129,7 +129,7 @@ func (device *Device) RoutineReceiveIncoming(
}
deathSpiral = 0
- device.awg.ASecMux.RLock()
+ device.awg.Mux.RLock()
// handle each packet in the batch
for i, size := range sizes[:count] {
if size < MinMessageSize {
@@ -140,37 +140,11 @@ func (device *Device) RoutineReceiveIncoming(
packet := bufsArrs[i][:size]
var msgType uint32
if device.isAWG() {
- // TODO:
- // if awg.WaitResponse.ShouldWait.IsSet() {
- // awg.WaitResponse.Channel <- struct{}{}
- // }
+ msgType, err = device.ProcessAWGPacket(size, &packet, bufsArrs[i])
- if assumedMsgType, ok := packetSizeToMsgType[size]; ok {
- junkSize := msgTypeToJunkSize[assumedMsgType]
- // transport size can align with other header types;
- // making sure we have the right msgType
- msgType = binary.LittleEndian.Uint32(packet[junkSize : junkSize+4])
- if msgType == assumedMsgType {
- packet = packet[junkSize:]
- } else {
- device.log.Verbosef("transport packet lined up with another msg type")
- msgType = binary.LittleEndian.Uint32(packet[:4])
- }
- } else {
- transportJunkSize := device.awg.ASecCfg.TransportHeaderJunkSize
- msgType = binary.LittleEndian.Uint32(packet[transportJunkSize : transportJunkSize+4])
- if msgType != MessageTransportType {
- // probably a junk packet
- device.log.Verbosef("aSec: Received message with unknown type: %d", msgType)
- continue
- }
-
- // remove junk from bufsArrs by shifting the packet
- // this buffer is also used for decryption, so it needs to be corrected
- copy(bufsArrs[i][:size], packet[transportJunkSize:])
- size -= transportJunkSize
- // need to reinitialize packet as well
- packet = packet[:size]
+ if err != nil {
+ device.log.Verbosef("awg: process packet: %v", err)
+ continue
}
} else {
msgType = binary.LittleEndian.Uint32(packet[:4])
@@ -259,7 +233,7 @@ func (device *Device) RoutineReceiveIncoming(
default:
}
}
- device.awg.ASecMux.RUnlock()
+ device.awg.Mux.RUnlock()
for peer, elemsContainer := range elemsByPeer {
if peer.isRunning.Load() {
peer.queue.inbound.c <- elemsContainer
@@ -318,7 +292,7 @@ func (device *Device) RoutineHandshake(id int) {
for elem := range device.queue.handshake.c {
- device.awg.ASecMux.RLock()
+ device.awg.Mux.RLock()
// handle cookie fields and ratelimiting
@@ -405,6 +379,9 @@ func (device *Device) RoutineHandshake(id int) {
goto skip
}
+ // have to reassign msgType for ranged msgType to work
+ msg.Type = elem.msgType
+
// consume initiation
peer := device.ConsumeMessageInitiation(&msg)
if peer == nil {
@@ -437,6 +414,9 @@ func (device *Device) RoutineHandshake(id int) {
goto skip
}
+ // have to reassign msgType for ranged msgType to work
+ msg.Type = elem.msgType
+
// consume response
peer := device.ConsumeMessageResponse(&msg)
@@ -470,7 +450,7 @@ func (device *Device) RoutineHandshake(id int) {
peer.SendKeepalive()
}
skip:
- device.awg.ASecMux.RUnlock()
+ device.awg.Mux.RUnlock()
device.PutMessageBuffer(elem.buffer)
}
}
diff --git a/device/send.go b/device/send.go
index 04ca2ad..e1aec72 100644
--- a/device/send.go
+++ b/device/send.go
@@ -130,7 +130,7 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
if peer.device.version >= VersionAwg {
var junks [][]byte
if peer.device.version == VersionAwgSpecialHandshake {
- peer.device.awg.ASecMux.RLock()
+ peer.device.awg.Mux.RLock()
// set junks depending on packet type
junks = peer.device.awg.HandshakeHandler.GenerateSpecialJunk()
if junks == nil {
@@ -141,18 +141,13 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
} else {
peer.device.log.Verbosef("%v - Special junks sent", peer)
}
- peer.device.awg.ASecMux.RUnlock()
+ peer.device.awg.Mux.RUnlock()
} else {
- junks = make([][]byte, 0, peer.device.awg.ASecCfg.JunkPacketCount)
- }
- peer.device.awg.ASecMux.RLock()
- err := peer.device.awg.JunkCreator.CreateJunkPackets(&junks)
- peer.device.awg.ASecMux.RUnlock()
-
- if err != nil {
- peer.device.log.Errorf("%v - %v", peer, err)
- return err
+ junks = make([][]byte, 0, peer.device.awg.Cfg.JunkPacketCount)
}
+ peer.device.awg.Mux.RLock()
+ peer.device.awg.JunkCreator.CreateJunkPackets(&junks)
+ peer.device.awg.Mux.RUnlock()
if len(junks) > 0 {
err = peer.SendBuffers(junks)
@@ -242,10 +237,17 @@ func (device *Device) SendHandshakeCookie(
device.log.Verbosef("Sending cookie response for denied handshake message for %v", initiatingElem.endpoint.DstToString())
sender := binary.LittleEndian.Uint32(initiatingElem.packet[4:8])
+ msgType, err := device.awg.GetMsgType(DefaultMessageCookieReplyType)
+ if err != nil {
+ device.log.Errorf("Get message type for cookie reply: %v", err)
+ return err
+ }
+
reply, err := device.cookieChecker.CreateReply(
initiatingElem.packet,
sender,
initiatingElem.endpoint.DstToBytes(),
+ msgType,
)
if err != nil {
device.log.Errorf("Failed to create cookie reply: %v", err)
@@ -528,7 +530,12 @@ func (device *Device) RoutineEncryption(id int) {
fieldReceiver := header[4:8]
fieldNonce := header[8:16]
- binary.LittleEndian.PutUint32(fieldType, MessageTransportType)
+ msgType, err := device.awg.GetMsgType(DefaultMessageTransportType)
+ if err != nil {
+ device.log.Errorf("get message type for transport: %v", err)
+ continue
+ }
+ binary.LittleEndian.PutUint32(fieldType, msgType)
binary.LittleEndian.PutUint32(fieldReceiver, elem.keypair.remoteIndex)
binary.LittleEndian.PutUint64(fieldNonce, elem.nonce)
diff --git a/device/uapi.go b/device/uapi.go
index e9f962a..d7b74cd 100644
--- a/device/uapi.go
+++ b/device/uapi.go
@@ -99,38 +99,36 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
}
if device.isAWG() {
- if device.awg.ASecCfg.JunkPacketCount != 0 {
- sendf("jc=%d", device.awg.ASecCfg.JunkPacketCount)
+ if device.awg.Cfg.JunkPacketCount != 0 {
+ sendf("jc=%d", device.awg.Cfg.JunkPacketCount)
}
- if device.awg.ASecCfg.JunkPacketMinSize != 0 {
- sendf("jmin=%d", device.awg.ASecCfg.JunkPacketMinSize)
+ if device.awg.Cfg.JunkPacketMinSize != 0 {
+ sendf("jmin=%d", device.awg.Cfg.JunkPacketMinSize)
}
- if device.awg.ASecCfg.JunkPacketMaxSize != 0 {
- sendf("jmax=%d", device.awg.ASecCfg.JunkPacketMaxSize)
+ if device.awg.Cfg.JunkPacketMaxSize != 0 {
+ sendf("jmax=%d", device.awg.Cfg.JunkPacketMaxSize)
}
- if device.awg.ASecCfg.InitHeaderJunkSize != 0 {
- sendf("s1=%d", device.awg.ASecCfg.InitHeaderJunkSize)
+ if device.awg.Cfg.InitHeaderJunkSize != 0 {
+ sendf("s1=%d", device.awg.Cfg.InitHeaderJunkSize)
}
- if device.awg.ASecCfg.ResponseHeaderJunkSize != 0 {
- sendf("s2=%d", device.awg.ASecCfg.ResponseHeaderJunkSize)
+ if device.awg.Cfg.ResponseHeaderJunkSize != 0 {
+ sendf("s2=%d", device.awg.Cfg.ResponseHeaderJunkSize)
}
- if device.awg.ASecCfg.CookieReplyHeaderJunkSize != 0 {
- sendf("s3=%d", device.awg.ASecCfg.CookieReplyHeaderJunkSize)
+ if device.awg.Cfg.CookieReplyHeaderJunkSize != 0 {
+ sendf("s3=%d", device.awg.Cfg.CookieReplyHeaderJunkSize)
}
- if device.awg.ASecCfg.TransportHeaderJunkSize != 0 {
- sendf("s4=%d", device.awg.ASecCfg.TransportHeaderJunkSize)
+ if device.awg.Cfg.TransportHeaderJunkSize != 0 {
+ sendf("s4=%d", device.awg.Cfg.TransportHeaderJunkSize)
}
- if device.awg.ASecCfg.InitPacketMagicHeader != 0 {
- sendf("h1=%d", device.awg.ASecCfg.InitPacketMagicHeader)
- }
- if device.awg.ASecCfg.ResponsePacketMagicHeader != 0 {
- sendf("h2=%d", device.awg.ASecCfg.ResponsePacketMagicHeader)
- }
- if device.awg.ASecCfg.UnderloadPacketMagicHeader != 0 {
- sendf("h3=%d", device.awg.ASecCfg.UnderloadPacketMagicHeader)
- }
- if device.awg.ASecCfg.TransportPacketMagicHeader != 0 {
- sendf("h4=%d", device.awg.ASecCfg.TransportPacketMagicHeader)
+ for i, magicHeader := range device.awg.Cfg.MagicHeaders.Values {
+ if magicHeader.Min > 4 {
+ if magicHeader.Min == magicHeader.Max {
+ sendf("h%d=%d", i+1, magicHeader.Min)
+ continue
+ }
+
+ sendf("h%d=%d-%d", i+1, magicHeader.Min, magicHeader.Max)
+ }
}
specialJunkIpcFields := device.awg.HandshakeHandler.SpecialJunk.IpcGetFields()
@@ -200,6 +198,8 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
deviceConfig := true
tempAwg := awg.Protocol{}
+ tempAwg.Cfg.MagicHeaders.Values = make([]awg.MagicHeader, 4)
+
scanner := bufio.NewScanner(r)
for scanner.Scan() {
line := scanner.Text()
@@ -312,8 +312,8 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol)
return ipcErrorf(ipc.IpcErrorInvalid, "parse junk_packet_count %w", err)
}
device.log.Verbosef("UAPI: Updating junk_packet_count")
- tempAwg.ASecCfg.JunkPacketCount = junkPacketCount
- tempAwg.ASecCfg.IsSet = true
+ tempAwg.Cfg.JunkPacketCount = junkPacketCount
+ tempAwg.Cfg.IsSet = true
case "jmin":
junkPacketMinSize, err := strconv.Atoi(value)
@@ -321,8 +321,8 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol)
return ipcErrorf(ipc.IpcErrorInvalid, "parse junk_packet_min_size %w", err)
}
device.log.Verbosef("UAPI: Updating junk_packet_min_size")
- tempAwg.ASecCfg.JunkPacketMinSize = junkPacketMinSize
- tempAwg.ASecCfg.IsSet = true
+ tempAwg.Cfg.JunkPacketMinSize = junkPacketMinSize
+ tempAwg.Cfg.IsSet = true
case "jmax":
junkPacketMaxSize, err := strconv.Atoi(value)
@@ -330,8 +330,8 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol)
return ipcErrorf(ipc.IpcErrorInvalid, "parse junk_packet_max_size %w", err)
}
device.log.Verbosef("UAPI: Updating junk_packet_max_size")
- tempAwg.ASecCfg.JunkPacketMaxSize = junkPacketMaxSize
- tempAwg.ASecCfg.IsSet = true
+ tempAwg.Cfg.JunkPacketMaxSize = junkPacketMaxSize
+ tempAwg.Cfg.IsSet = true
case "s1":
initPacketJunkSize, err := strconv.Atoi(value)
@@ -339,8 +339,8 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol)
return ipcErrorf(ipc.IpcErrorInvalid, "parse init_packet_junk_size %w", err)
}
device.log.Verbosef("UAPI: Updating init_packet_junk_size")
- tempAwg.ASecCfg.InitHeaderJunkSize = initPacketJunkSize
- tempAwg.ASecCfg.IsSet = true
+ tempAwg.Cfg.InitHeaderJunkSize = initPacketJunkSize
+ tempAwg.Cfg.IsSet = true
case "s2":
responsePacketJunkSize, err := strconv.Atoi(value)
@@ -348,8 +348,8 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol)
return ipcErrorf(ipc.IpcErrorInvalid, "parse response_packet_junk_size %w", err)
}
device.log.Verbosef("UAPI: Updating response_packet_junk_size")
- tempAwg.ASecCfg.ResponseHeaderJunkSize = responsePacketJunkSize
- tempAwg.ASecCfg.IsSet = true
+ tempAwg.Cfg.ResponseHeaderJunkSize = responsePacketJunkSize
+ tempAwg.Cfg.IsSet = true
case "s3":
cookieReplyPacketJunkSize, err := strconv.Atoi(value)
@@ -357,8 +357,8 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol)
return ipcErrorf(ipc.IpcErrorInvalid, "parse cookie_reply_packet_junk_size %w", err)
}
device.log.Verbosef("UAPI: Updating cookie_reply_packet_junk_size")
- tempAwg.ASecCfg.CookieReplyHeaderJunkSize = cookieReplyPacketJunkSize
- tempAwg.ASecCfg.IsSet = true
+ tempAwg.Cfg.CookieReplyHeaderJunkSize = cookieReplyPacketJunkSize
+ tempAwg.Cfg.IsSet = true
case "s4":
transportPacketJunkSize, err := strconv.Atoi(value)
@@ -366,47 +366,47 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol)
return ipcErrorf(ipc.IpcErrorInvalid, "parse transport_packet_junk_size %w", err)
}
device.log.Verbosef("UAPI: Updating transport_packet_junk_size")
- tempAwg.ASecCfg.TransportHeaderJunkSize = transportPacketJunkSize
- tempAwg.ASecCfg.IsSet = true
-
+ tempAwg.Cfg.TransportHeaderJunkSize = transportPacketJunkSize
+ tempAwg.Cfg.IsSet = true
case "h1":
- initPacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
+ initMagicHeader, err := awg.ParseMagicHeader(key, value)
if err != nil {
- return ipcErrorf(ipc.IpcErrorInvalid, "parse init_packet_magic_header %w", err)
+ return ipcErrorf(ipc.IpcErrorInvalid, "uapi: %w", err)
}
- tempAwg.ASecCfg.InitPacketMagicHeader = uint32(initPacketMagicHeader)
- tempAwg.ASecCfg.IsSet = true
+ tempAwg.Cfg.MagicHeaders.Values[0] = initMagicHeader
+ tempAwg.Cfg.IsSet = true
case "h2":
- responsePacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
+ responseMagicHeader, err := awg.ParseMagicHeader(key, value)
if err != nil {
- return ipcErrorf(ipc.IpcErrorInvalid, "parse response_packet_magic_header %w", err)
+ return ipcErrorf(ipc.IpcErrorInvalid, "uapi: %w", err)
}
- tempAwg.ASecCfg.ResponsePacketMagicHeader = uint32(responsePacketMagicHeader)
- tempAwg.ASecCfg.IsSet = true
+ tempAwg.Cfg.MagicHeaders.Values[1] = responseMagicHeader
+ tempAwg.Cfg.IsSet = true
case "h3":
- underloadPacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
+ cookieReplyMagicHeader, err := awg.ParseMagicHeader(key, value)
if err != nil {
- return ipcErrorf(ipc.IpcErrorInvalid, "parse underload_packet_magic_header %w", err)
+ return ipcErrorf(ipc.IpcErrorInvalid, "uapi: %w", err)
}
- tempAwg.ASecCfg.UnderloadPacketMagicHeader = uint32(underloadPacketMagicHeader)
- tempAwg.ASecCfg.IsSet = true
+ tempAwg.Cfg.MagicHeaders.Values[2] = cookieReplyMagicHeader
+ tempAwg.Cfg.IsSet = true
case "h4":
- transportPacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
+ transportMagicHeader, err := awg.ParseMagicHeader(key, value)
if err != nil {
- return ipcErrorf(ipc.IpcErrorInvalid, "parse transport_packet_magic_header %w", err)
+ return ipcErrorf(ipc.IpcErrorInvalid, "uapi: %w", err)
}
- tempAwg.ASecCfg.TransportPacketMagicHeader = uint32(transportPacketMagicHeader)
- tempAwg.ASecCfg.IsSet = true
+
+ tempAwg.Cfg.MagicHeaders.Values[3] = transportMagicHeader
+ tempAwg.Cfg.IsSet = true
case "i1", "i2", "i3", "i4", "i5":
if len(value) == 0 {
device.log.Verbosef("UAPI: received empty %s", key)
return nil
}
- generators, err := awg.Parse(key, value)
+ generators, err := awg.ParseTagJunkGenerator(key, value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "invalid %s: %w", key, err)
}
@@ -419,7 +419,7 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol)
return nil
}
- generators, err := awg.Parse(key, value)
+ generators, err := awg.ParseTagJunkGenerator(key, value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "invalid %s: %w", key, err)
}
diff --git a/go.mod b/go.mod
index 5e5f34d..8c4372d 100644
--- a/go.mod
+++ b/go.mod
@@ -5,9 +5,9 @@ go 1.24.4
require (
github.com/stretchr/testify v1.10.0
github.com/tevino/abool v1.2.0
- github.com/tevino/abool/v2 v2.1.0
go.uber.org/atomic v1.11.0
golang.org/x/crypto v0.39.0
+ golang.org/x/exp v0.0.0-20230725093048-515e97ebf090
golang.org/x/net v0.41.0
golang.org/x/sys v0.33.0
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
diff --git a/go.sum b/go.sum
index 6b8f36b..3d8b3c2 100644
--- a/go.sum
+++ b/go.sum
@@ -2,24 +2,18 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg=
github.com/google/btree v1.1.3/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4=
-github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
-github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
-github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/tevino/abool v1.2.0 h1:heAkClL8H6w+mK5md9dzsuohKeXHUpY7Vw0ZCKW+huA=
github.com/tevino/abool v1.2.0/go.mod h1:qc66Pna1RiIsPa7O4Egxxs9OqkuxDX55zznh9K07Tzg=
-github.com/tevino/abool/v2 v2.1.0 h1:7w+Vf9f/5gmKT4m4qkayb33/92M+Um45F2BkHOR+L/c=
-github.com/tevino/abool/v2 v2.1.0/go.mod h1:+Lmlqk6bHDWHqN1cbxqhwEAwMPXgc8I1SDEamtseuXY=
go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE=
go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0=
golang.org/x/crypto v0.39.0 h1:SHs+kF4LP+f+p14esP5jAoDpHU8Gu/v9lFRK6IT5imM=
golang.org/x/crypto v0.39.0/go.mod h1:L+Xg3Wf6HoL4Bn4238Z6ft6KfEpN0tJGo53AAPC632U=
-golang.org/x/mod v0.13.0 h1:I/DsJXRlw/8l/0c24sM9yb0T4z9liZTduXvdAWYiysY=
-golang.org/x/mod v0.21.0 h1:vvrHzRwRfVKSiLrG+d4FMl/Qi4ukBCE6kZlTUkDYRT0=
-golang.org/x/mod v0.21.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY=
+golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY=
+golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc=
golang.org/x/net v0.41.0 h1:vBTly1HeNPEn3wtREYfy4GZ/NECgw2Cnl+nK6Nz3uvw=
golang.org/x/net v0.41.0/go.mod h1:B/K4NNqkfmg07DQYrbwvSluqCJOOXwUjeb/5lOisjbA=
golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=
@@ -34,7 +28,3 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gvisor.dev/gvisor v0.0.0-20231202080848-1f7806d17489 h1:ze1vwAdliUAr68RQ5NtufWaXaOg8WUO2OACzEV+TNdE=
gvisor.dev/gvisor v0.0.0-20231202080848-1f7806d17489/go.mod h1:10sU+Uh5KKNv1+2x2A0Gvzt8FjD3ASIhorV3YsauXhk=
-gvisor.dev/gvisor v0.0.0-20250428193742-2d800c3129d5 h1:sfK5nHuG7lRFZ2FdTT3RimOqWBg8IrVm+/Vko1FVOsk=
-gvisor.dev/gvisor v0.0.0-20250428193742-2d800c3129d5/go.mod h1:3r5CMtNQMKIvBlrmM9xWUNamjKBYPOWyXOjmg5Kts3g=
-gvisor.dev/gvisor v0.0.0-20250606233247-e3c4c4cad86f h1:zmc4cHEcCudRt2O8VsCW7nYLfAsbVY2i910/DAop1TM=
-gvisor.dev/gvisor v0.0.0-20250606233247-e3c4c4cad86f/go.mod h1:3r5CMtNQMKIvBlrmM9xWUNamjKBYPOWyXOjmg5Kts3g=