feat: ranged magic headers

This commit is contained in:
Mark Puha 2025-07-15 19:05:22 +02:00
parent 1abd24b5b9
commit 15d7259cd4
18 changed files with 1084 additions and 459 deletions

View file

@ -3,142 +3,88 @@ package awg
import ( import (
"bytes" "bytes"
"fmt" "fmt"
"slices"
"strconv"
"strings"
"sync" "sync"
"github.com/tevino/abool" "github.com/tevino/abool"
) )
type aSecCfgType struct { type Cfg struct {
IsSet bool IsSet bool
JunkPacketCount int JunkPacketCount int
JunkPacketMinSize int JunkPacketMinSize int
JunkPacketMaxSize int JunkPacketMaxSize int
InitHeaderJunkSize int InitHeaderJunkSize int
ResponseHeaderJunkSize int ResponseHeaderJunkSize int
CookieReplyHeaderJunkSize int CookieReplyHeaderJunkSize int
TransportHeaderJunkSize int TransportHeaderJunkSize int
InitPacketMagicHeader uint32
ResponsePacketMagicHeader uint32
UnderloadPacketMagicHeader uint32
TransportPacketMagicHeader uint32
// InitPacketMagicHeader Limit
// ResponsePacketMagicHeader Limit
// UnderloadPacketMagicHeader Limit
// TransportPacketMagicHeader Limit
}
type Limit struct { MagicHeaders MagicHeaders
Min uint32
Max uint32
HeaderType uint32
}
func NewLimit(min, max, headerType uint32) (Limit, error) {
if min > max {
return Limit{}, fmt.Errorf("min (%d) cannot be greater than max (%d)", min, max)
}
return Limit{
Min: min,
Max: max,
HeaderType: headerType,
}, nil
}
func ParseMagicHeader(key, value string, defaultHeaderType uint32) (Limit, error) {
// tempAwg.ASecCfg.InitPacketMagicHeader, err = awg.NewLimit(uint32(initPacketMagicHeaderMin), uint32(initPacketMagicHeaderMax), DNewLimit(min, max, headerType)efaultMessageInitiationType)
// var min, max, headerType uint32
// _, err := fmt.Sscanf(value, "%d-%d:%d", &min, &max, &headerType)
// if err != nil {
// return Limit{}, fmt.Errorf("invalid magic header format: %s", value)
// }
limits := strings.Split(value, "-")
if len(limits) != 2 {
return Limit{}, fmt.Errorf("invalid format for key: %s; %s", key, value)
}
min, err := strconv.ParseUint(limits[0], 10, 32)
if err != nil {
return Limit{}, fmt.Errorf("parse min key: %s; value: ; %w", key, limits[0], err)
}
max, err := strconv.ParseUint(limits[1], 10, 32)
if err != nil {
return Limit{}, fmt.Errorf("parse max key: %s; value: ; %w", key, limits[0], err)
}
limit, err := NewLimit(uint32(min), uint32(max), defaultHeaderType)
if err != nil {
return Limit{}, fmt.Errorf("new lmit key: %s; value: ; %w", key, limits[0], err)
}
return limit, nil
}
type Limits []Limit
func NewLimits(limits []Limit) Limits {
slices.SortFunc(limits, func(a, b Limit) int {
if a.Min < b.Min {
return -1
} else if a.Min > b.Min {
return 1
}
return 0
})
return Limits(limits)
} }
type Protocol struct { type Protocol struct {
IsASecOn abool.AtomicBool IsOn abool.AtomicBool
// TODO: revision the need of the mutex // TODO: revision the need of the mutex
ASecMux sync.RWMutex Mux sync.RWMutex
ASecCfg aSecCfgType Cfg Cfg
JunkCreator junkCreator JunkCreator JunkCreator
HandshakeHandler SpecialHandshakeHandler HandshakeHandler SpecialHandshakeHandler
} }
func (protocol *Protocol) CreateInitHeaderJunk() ([]byte, error) { func (protocol *Protocol) CreateInitHeaderJunk() ([]byte, error) {
return protocol.createHeaderJunk(protocol.ASecCfg.InitHeaderJunkSize) protocol.Mux.RLock()
defer protocol.Mux.RUnlock()
return protocol.createHeaderJunk(protocol.Cfg.InitHeaderJunkSize, 0)
} }
func (protocol *Protocol) CreateResponseHeaderJunk() ([]byte, error) { func (protocol *Protocol) CreateResponseHeaderJunk() ([]byte, error) {
return protocol.createHeaderJunk(protocol.ASecCfg.ResponseHeaderJunkSize) protocol.Mux.RLock()
defer protocol.Mux.RUnlock()
return protocol.createHeaderJunk(protocol.Cfg.ResponseHeaderJunkSize, 0)
} }
func (protocol *Protocol) CreateCookieReplyHeaderJunk() ([]byte, error) { func (protocol *Protocol) CreateCookieReplyHeaderJunk() ([]byte, error) {
return protocol.createHeaderJunk(protocol.ASecCfg.CookieReplyHeaderJunkSize) protocol.Mux.RLock()
defer protocol.Mux.RUnlock()
return protocol.createHeaderJunk(protocol.Cfg.CookieReplyHeaderJunkSize, 0)
} }
func (protocol *Protocol) CreateTransportHeaderJunk(packetSize int) ([]byte, error) { func (protocol *Protocol) CreateTransportHeaderJunk(packetSize int) ([]byte, error) {
return protocol.createHeaderJunk(protocol.ASecCfg.TransportHeaderJunkSize, packetSize) protocol.Mux.RLock()
defer protocol.Mux.RUnlock()
return protocol.createHeaderJunk(protocol.Cfg.TransportHeaderJunkSize, packetSize)
} }
func (protocol *Protocol) createHeaderJunk(junkSize int, optExtraSize ...int) ([]byte, error) { func (protocol *Protocol) createHeaderJunk(junkSize int, extraSize int) ([]byte, error) {
extraSize := 0 if junkSize == 0 {
if len(optExtraSize) == 1 { return nil, nil
extraSize = optExtraSize[0]
} }
var junk []byte buf := make([]byte, 0, junkSize+extraSize)
protocol.ASecMux.RLock() writer := bytes.NewBuffer(buf[:0])
if junkSize != 0 {
buf := make([]byte, 0, junkSize+extraSize) err := protocol.JunkCreator.AppendJunk(writer, junkSize)
writer := bytes.NewBuffer(buf[:0]) if err != nil {
err := protocol.JunkCreator.AppendJunk(writer, junkSize) return nil, fmt.Errorf("append junk: %w", err)
if err != nil { }
protocol.ASecMux.RUnlock()
return nil, err return writer.Bytes(), nil
}
func (protocol *Protocol) GetMagicHeaderMinFor(msgType uint32) (uint32, error) {
for _, magicHeader := range protocol.Cfg.MagicHeaders.Values {
if magicHeader.Min <= msgType && msgType <= magicHeader.Max {
return magicHeader.Min, nil
} }
junk = writer.Bytes()
} }
protocol.ASecMux.RUnlock()
return junk, nil return 0, fmt.Errorf("no header for value: %d", msgType)
}
func (protocol *Protocol) GetMsgType(defaultMsgType uint32) (uint32, error) {
return protocol.Cfg.MagicHeaders.Get(defaultMsgType)
} }

View file

@ -2,69 +2,49 @@ package awg
import ( import (
"bytes" "bytes"
crand "crypto/rand"
"fmt" "fmt"
v2 "math/rand/v2"
) )
type junkCreator struct { type JunkCreator struct {
aSecCfg aSecCfgType cfg Cfg
cha8Rand *v2.ChaCha8 randomGenerator PRNG[int]
} }
// TODO: refactor param to only pass the junk related params // TODO: refactor param to only pass the junk related params
func NewJunkCreator(aSecCfg aSecCfgType) (junkCreator, error) { func NewJunkCreator(cfg Cfg) JunkCreator {
buf := make([]byte, 32) return JunkCreator{cfg: cfg, randomGenerator: NewPRNG[int]()}
_, err := crand.Read(buf)
if err != nil {
return junkCreator{}, err
}
return junkCreator{aSecCfg: aSecCfg, cha8Rand: v2.NewChaCha8([32]byte(buf))}, nil
} }
// Should be called with aSecMux RLocked // Should be called with awg mux RLocked
func (jc *junkCreator) CreateJunkPackets(junks *[][]byte) error { func (jc *JunkCreator) CreateJunkPackets(junks *[][]byte) {
if jc.aSecCfg.JunkPacketCount == 0 { if jc.cfg.JunkPacketCount == 0 {
return nil return
} }
for range jc.aSecCfg.JunkPacketCount { for range jc.cfg.JunkPacketCount {
packetSize := jc.randomPacketSize() packetSize := jc.randomPacketSize()
junk, err := jc.randomJunkWithSize(packetSize) junk := jc.randomJunkWithSize(packetSize)
if err != nil {
return fmt.Errorf("create junk packet: %v", err)
}
*junks = append(*junks, junk) *junks = append(*junks, junk)
} }
return nil return
} }
// Should be called with aSecMux RLocked // Should be called with awg mux RLocked
func (jc *junkCreator) randomPacketSize() int { func (jc *JunkCreator) randomPacketSize() int {
return int( return jc.randomGenerator.RandomSizeInRange(jc.cfg.JunkPacketMinSize, jc.cfg.JunkPacketMaxSize)
jc.cha8Rand.Uint64()%uint64(
jc.aSecCfg.JunkPacketMaxSize-jc.aSecCfg.JunkPacketMinSize,
),
) + jc.aSecCfg.JunkPacketMinSize
} }
// Should be called with aSecMux RLocked // Should be called with awg mux RLocked
func (jc *junkCreator) AppendJunk(writer *bytes.Buffer, size int) error { func (jc *JunkCreator) AppendJunk(writer *bytes.Buffer, size int) error {
headerJunk, err := jc.randomJunkWithSize(size) headerJunk := jc.randomJunkWithSize(size)
if err != nil { _, err := writer.Write(headerJunk)
return fmt.Errorf("create header junk: %v", err)
}
_, err = writer.Write(headerJunk)
if err != nil { if err != nil {
return fmt.Errorf("write header junk: %v", err) return fmt.Errorf("write header junk: %v", err)
} }
return nil return nil
} }
// Should be called with aSecMux RLocked // Should be called with awg mux RLocked
func (jc *junkCreator) randomJunkWithSize(size int) ([]byte, error) { func (jc *JunkCreator) randomJunkWithSize(size int) []byte {
// TODO: use a memory pool to allocate return jc.randomGenerator.ReadSize(size)
junk := make([]byte, size)
_, err := jc.cha8Rand.Read(junk)
return junk, err
} }

View file

@ -6,43 +6,29 @@ import (
"testing" "testing"
) )
func setUpJunkCreator(t *testing.T) (junkCreator, error) { func setUpJunkCreator() JunkCreator {
jc, err := NewJunkCreator(aSecCfgType{ jc := NewJunkCreator(Cfg{
IsSet: true, IsSet: true,
JunkPacketCount: 5, JunkPacketCount: 5,
JunkPacketMinSize: 500, JunkPacketMinSize: 500,
JunkPacketMaxSize: 1000, JunkPacketMaxSize: 1000,
InitHeaderJunkSize: 30, InitHeaderJunkSize: 30,
ResponseHeaderJunkSize: 40, ResponseHeaderJunkSize: 40,
InitPacketMagicHeader: 123456, // TODO
ResponsePacketMagicHeader: 67543, // InitPacketMagicHeader: 123456,
UnderloadPacketMagicHeader: 32345, // ResponsePacketMagicHeader: 67543,
TransportPacketMagicHeader: 123123, // UnderloadPacketMagicHeader: 32345,
// TransportPacketMagicHeader: 123123,
}) })
if err != nil { return jc
t.Errorf("failed to create junk creator %v", err)
return junkCreator{}, err
}
return jc, nil
} }
func Test_junkCreator_createJunkPackets(t *testing.T) { func Test_junkCreator_createJunkPackets(t *testing.T) {
jc, err := setUpJunkCreator(t) jc := setUpJunkCreator()
if err != nil {
return
}
t.Run("valid", func(t *testing.T) { t.Run("valid", func(t *testing.T) {
got := make([][]byte, 0, jc.aSecCfg.JunkPacketCount) got := make([][]byte, 0, jc.cfg.JunkPacketCount)
err := jc.CreateJunkPackets(&got) jc.CreateJunkPackets(&got)
if err != nil {
t.Errorf(
"junkCreator.createJunkPackets() = %v; failed",
err,
)
return
}
seen := make(map[string]bool) seen := make(map[string]bool)
for _, junk := range got { for _, junk := range got {
key := string(junk) key := string(junk)
@ -61,34 +47,28 @@ func Test_junkCreator_createJunkPackets(t *testing.T) {
func Test_junkCreator_randomJunkWithSize(t *testing.T) { func Test_junkCreator_randomJunkWithSize(t *testing.T) {
t.Run("valid", func(t *testing.T) { t.Run("valid", func(t *testing.T) {
jc, err := setUpJunkCreator(t) jc := setUpJunkCreator()
if err != nil { r1 := jc.randomJunkWithSize(10)
return r2 := jc.randomJunkWithSize(10)
}
r1, _ := jc.randomJunkWithSize(10)
r2, _ := jc.randomJunkWithSize(10)
fmt.Printf("%v\n%v\n", r1, r2) fmt.Printf("%v\n%v\n", r1, r2)
if bytes.Equal(r1, r2) { if bytes.Equal(r1, r2) {
t.Errorf("same junks %v", err) t.Errorf("same junks")
return return
} }
}) })
} }
func Test_junkCreator_randomPacketSize(t *testing.T) { func Test_junkCreator_randomPacketSize(t *testing.T) {
jc, err := setUpJunkCreator(t) jc := setUpJunkCreator()
if err != nil {
return
}
for range [30]struct{}{} { for range [30]struct{}{} {
t.Run("valid", func(t *testing.T) { t.Run("valid", func(t *testing.T) {
if got := jc.randomPacketSize(); jc.aSecCfg.JunkPacketMinSize > got || if got := jc.randomPacketSize(); jc.cfg.JunkPacketMinSize > got ||
got > jc.aSecCfg.JunkPacketMaxSize { got > jc.cfg.JunkPacketMaxSize {
t.Errorf( t.Errorf(
"junkCreator.randomPacketSize() = %v, not between range [%v,%v]", "junkCreator.randomPacketSize() = %v, not between range [%v,%v]",
got, got,
jc.aSecCfg.JunkPacketMinSize, jc.cfg.JunkPacketMinSize,
jc.aSecCfg.JunkPacketMaxSize, jc.cfg.JunkPacketMaxSize,
) )
} }
}) })
@ -96,10 +76,7 @@ func Test_junkCreator_randomPacketSize(t *testing.T) {
} }
func Test_junkCreator_appendJunk(t *testing.T) { func Test_junkCreator_appendJunk(t *testing.T) {
jc, err := setUpJunkCreator(t) jc := setUpJunkCreator()
if err != nil {
return
}
t.Run("valid", func(t *testing.T) { t.Run("valid", func(t *testing.T) {
s := "apple" s := "apple"
buffer := bytes.NewBuffer([]byte(s)) buffer := bytes.NewBuffer([]byte(s))

View file

@ -0,0 +1,93 @@
package awg
import (
"cmp"
"fmt"
"slices"
"strconv"
"strings"
)
type MagicHeader struct {
Min uint32
Max uint32
}
func NewMagicHeaderSameValue(value uint32) MagicHeader {
return MagicHeader{Min: value, Max: value}
}
func NewMagicHeader(min, max uint32) (MagicHeader, error) {
if min > max {
return MagicHeader{}, fmt.Errorf("min (%d) cannot be greater than max (%d)", min, max)
}
return MagicHeader{Min: min, Max: max}, nil
}
func ParseMagicHeader(key, value string) (MagicHeader, error) {
splitLimits := strings.Split(value, "-")
if len(splitLimits) != 2 {
// if there is no hyphen, we treat it as single magic header value
magicHeader, err := strconv.ParseUint(value, 10, 32)
if err != nil {
return MagicHeader{}, fmt.Errorf("parse key: %s; value: %s; %w", key, value, err)
}
return NewMagicHeader(uint32(magicHeader), uint32(magicHeader))
} else if len(splitLimits[0]) == 0 || len(splitLimits[1]) == 0 {
return MagicHeader{}, fmt.Errorf("invalid value for key: %s; value: %s; expected format: min-max", key, value)
}
min, err := strconv.ParseUint(splitLimits[0], 10, 32)
if err != nil {
return MagicHeader{}, fmt.Errorf("parse min key: %s; value: %s; %w", key, splitLimits[0], err)
}
max, err := strconv.ParseUint(splitLimits[1], 10, 32)
if err != nil {
return MagicHeader{}, fmt.Errorf("parse max key: %s; value: %s; %w", key, splitLimits[1], err)
}
magicHeader, err := NewMagicHeader(uint32(min), uint32(max))
if err != nil {
return MagicHeader{}, fmt.Errorf("new magicHeader key: %s; value: %s-%s; %w", key, splitLimits[0], splitLimits[1], err)
}
return magicHeader, nil
}
type MagicHeaders struct {
Values []MagicHeader
randomGenerator RandomNumberGenerator[uint32]
}
func NewMagicHeaders(headerValues []MagicHeader) (MagicHeaders, error) {
if len(headerValues) != 4 {
return MagicHeaders{}, fmt.Errorf("all header types should be included: %v", headerValues)
}
sortedMagicHeaders := slices.SortedFunc(slices.Values(headerValues), func(lhs MagicHeader, rhs MagicHeader) int {
return cmp.Compare(lhs.Min, rhs.Min)
})
for i := range 3 {
if sortedMagicHeaders[i].Max >= sortedMagicHeaders[i+1].Min {
return MagicHeaders{}, fmt.Errorf(
"magic headers shouldn't overlap; %v > %v",
sortedMagicHeaders[i].Max,
sortedMagicHeaders[i+1].Min,
)
}
}
return MagicHeaders{Values: headerValues, randomGenerator: NewPRNG[uint32]()}, nil
}
func (mh *MagicHeaders) Get(defaultMsgType uint32) (uint32, error) {
if defaultMsgType == 0 || defaultMsgType > 4 {
return 0, fmt.Errorf("invalid msg type: %d", defaultMsgType)
}
return mh.randomGenerator.RandomSizeInRange(mh.Values[defaultMsgType-1].Min, mh.Values[defaultMsgType-1].Max), nil
}

View file

@ -0,0 +1,488 @@
package awg
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestNewMagicHeaderSameValue(t *testing.T) {
tests := []struct {
name string
value uint32
expected MagicHeader
}{
{
name: "zero value",
value: 0,
expected: MagicHeader{Min: 0, Max: 0},
},
{
name: "small value",
value: 1,
expected: MagicHeader{Min: 1, Max: 1},
},
{
name: "large value",
value: 4294967295, // max uint32
expected: MagicHeader{Min: 4294967295, Max: 4294967295},
},
{
name: "medium value",
value: 1000,
expected: MagicHeader{Min: 1000, Max: 1000},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
result := NewMagicHeaderSameValue(tt.value)
require.Equal(t, tt.expected, result)
})
}
}
func TestNewMagicHeader(t *testing.T) {
tests := []struct {
name string
min uint32
max uint32
expected MagicHeader
errorMsg string
}{
{
name: "valid range",
min: 1,
max: 10,
expected: MagicHeader{Min: 1, Max: 10},
},
{
name: "equal values",
min: 5,
max: 5,
expected: MagicHeader{Min: 5, Max: 5},
},
{
name: "zero range",
min: 0,
max: 0,
expected: MagicHeader{Min: 0, Max: 0},
},
{
name: "max uint32 range",
min: 4294967294,
max: 4294967295,
expected: MagicHeader{Min: 4294967294, Max: 4294967295},
},
{
name: "min greater than max",
min: 10,
max: 5,
expected: MagicHeader{},
errorMsg: "min (10) cannot be greater than max (5)",
},
{
name: "large min greater than max",
min: 4294967295,
max: 1,
expected: MagicHeader{},
errorMsg: "min (4294967295) cannot be greater than max (1)",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
result, err := NewMagicHeader(tt.min, tt.max)
if tt.errorMsg != "" {
require.Error(t, err)
require.Contains(t, err.Error(), tt.errorMsg)
require.Equal(t, MagicHeader{}, result)
} else {
require.NoError(t, err)
require.Equal(t, tt.expected, result)
}
})
}
}
func TestParseMagicHeader(t *testing.T) {
tests := []struct {
name string
key string
value string
expected MagicHeader
errorMsg string
}{
{
name: "single value",
key: "header1",
value: "100",
expected: MagicHeader{Min: 100, Max: 100},
},
{
name: "valid range",
key: "header2",
value: "10-20",
expected: MagicHeader{Min: 10, Max: 20},
},
{
name: "zero single value",
key: "header3",
value: "0",
expected: MagicHeader{Min: 0, Max: 0},
},
{
name: "zero range",
key: "header4",
value: "0-0",
expected: MagicHeader{Min: 0, Max: 0},
},
{
name: "max uint32 single",
key: "header5",
value: "4294967295",
expected: MagicHeader{Min: 4294967295, Max: 4294967295},
},
{
name: "max uint32 range",
key: "header6",
value: "4294967294-4294967295",
expected: MagicHeader{Min: 4294967294, Max: 4294967295},
},
{
name: "invalid single value - not number",
key: "header7",
value: "abc",
expected: MagicHeader{},
errorMsg: "parse key: header7; value: abc;",
},
{
name: "invalid single value - negative",
key: "header8",
value: "-5",
expected: MagicHeader{},
errorMsg: "invalid value for key: header8; value: -5;",
},
{
name: "invalid single value - too large",
key: "header9",
value: "4294967296",
expected: MagicHeader{},
errorMsg: "parse key: header9; value: 4294967296;",
},
{
name: "invalid range - min not number",
key: "header10",
value: "abc-10",
expected: MagicHeader{},
errorMsg: "parse min key: header10; value: abc;",
},
{
name: "invalid range - max not number",
key: "header11",
value: "10-abc",
expected: MagicHeader{},
errorMsg: "parse max key: header11; value: abc;",
},
{
name: "invalid range - min greater than max",
key: "header12",
value: "20-10",
expected: MagicHeader{},
errorMsg: "new magicHeader key: header12; value: 20-10;",
},
{
name: "invalid range - too many parts",
key: "header13",
value: "10-20-30",
expected: MagicHeader{},
errorMsg: "parse key: header13; value: 10-20-30;",
},
{
name: "empty value",
key: "header14",
value: "",
expected: MagicHeader{},
errorMsg: "parse key: header14; value: ;",
},
{
name: "hyphen only",
key: "header15",
value: "-",
expected: MagicHeader{},
errorMsg: "invalid value for key: header15; value: -;",
},
{
name: "empty min",
key: "header16",
value: "-10",
expected: MagicHeader{},
errorMsg: "invalid value for key: header16; value: -10;",
},
{
name: "empty max",
key: "header17",
value: "10-",
expected: MagicHeader{},
errorMsg: "invalid value for key: header17; value: 10-;",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
result, err := ParseMagicHeader(tt.key, tt.value)
if tt.errorMsg != "" {
require.Error(t, err)
require.Contains(t, err.Error(), tt.errorMsg)
require.Equal(t, MagicHeader{}, result)
} else {
require.NoError(t, err)
require.Equal(t, tt.expected, result)
}
})
}
}
func TestNewMagicHeaders(t *testing.T) {
tests := []struct {
name string
magicHeaders []MagicHeader
errorMsg string
}{
{
name: "valid non-overlapping headers",
magicHeaders: []MagicHeader{
{Min: 1, Max: 10},
{Min: 11, Max: 20},
{Min: 21, Max: 30},
{Min: 31, Max: 40},
},
},
{
name: "valid adjacent headers",
magicHeaders: []MagicHeader{
{Min: 1, Max: 1},
{Min: 2, Max: 2},
{Min: 3, Max: 3},
{Min: 4, Max: 4},
},
},
{
name: "valid zero-based headers",
magicHeaders: []MagicHeader{
{Min: 0, Max: 0},
{Min: 1, Max: 1},
{Min: 2, Max: 2},
{Min: 3, Max: 3},
},
},
{
name: "valid large value headers",
magicHeaders: []MagicHeader{
{Min: 4294967290, Max: 4294967291},
{Min: 4294967292, Max: 4294967293},
{Min: 4294967294, Max: 4294967294},
{Min: 4294967295, Max: 4294967295},
},
},
{
name: "too few headers",
magicHeaders: []MagicHeader{
{Min: 1, Max: 10},
{Min: 11, Max: 20},
{Min: 21, Max: 30},
},
errorMsg: "all header types should be included:",
},
{
name: "too many headers",
magicHeaders: []MagicHeader{
{Min: 1, Max: 10},
{Min: 11, Max: 20},
{Min: 21, Max: 30},
{Min: 31, Max: 40},
{Min: 41, Max: 50},
},
errorMsg: "all header types should be included:",
},
{
name: "empty headers",
magicHeaders: []MagicHeader{},
errorMsg: "all header types should be included:",
},
{
name: "overlapping headers",
magicHeaders: []MagicHeader{
{Min: 1, Max: 15},
{Min: 10, Max: 20},
{Min: 25, Max: 30},
{Min: 35, Max: 40},
},
errorMsg: "magic headers shouldn't overlap;",
},
{
name: "overlapping headers at limit-first",
magicHeaders: []MagicHeader{
{Min: 1, Max: 10},
{Min: 10, Max: 20},
{Min: 25, Max: 30},
{Min: 35, Max: 40},
},
errorMsg: "magic headers shouldn't overlap;",
},
{
name: "overlapping headers at limit-second",
magicHeaders: []MagicHeader{
{Min: 1, Max: 10},
{Min: 15, Max: 25},
{Min: 25, Max: 30},
{Min: 35, Max: 40},
},
errorMsg: "magic headers shouldn't overlap;",
},
{
name: "overlapping headers at limit-third",
magicHeaders: []MagicHeader{
{Min: 1, Max: 10},
{Min: 15, Max: 25},
{Min: 30, Max: 35},
{Min: 35, Max: 40},
},
errorMsg: "magic headers shouldn't overlap;",
},
{
name: "identical ranges",
magicHeaders: []MagicHeader{
{Min: 10, Max: 20},
{Min: 10, Max: 20},
{Min: 25, Max: 30},
{Min: 35, Max: 40},
},
errorMsg: "magic headers shouldn't overlap;",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
result, err := NewMagicHeaders(tt.magicHeaders)
if tt.errorMsg != "" {
require.Error(t, err)
require.Contains(t, err.Error(), tt.errorMsg)
require.Equal(t, MagicHeaders{}, result)
} else {
require.NoError(t, err)
require.Equal(t, tt.magicHeaders, result.Values)
require.NotNil(t, result.randomGenerator)
}
})
}
}
// Mock PRNG for testing
type mockPRNG struct {
returnValue uint32
}
func (m *mockPRNG) RandomSizeInRange(min, max uint32) uint32 {
return m.returnValue
}
func (m *mockPRNG) Get() uint64 {
return 0
}
func (m *mockPRNG) ReadSize(size int) []byte {
return make([]byte, size)
}
func TestMagicHeaders_Get(t *testing.T) {
// Create test headers
headers := []MagicHeader{
{Min: 1, Max: 10},
{Min: 11, Max: 20},
{Min: 21, Max: 30},
{Min: 31, Max: 40},
}
tests := []struct {
name string
defaultMsgType uint32
mockValue uint32
expectedValue uint32
errorMsg string
}{
{
name: "valid type 1",
defaultMsgType: 1,
mockValue: 5,
expectedValue: 5,
},
{
name: "valid type 2",
defaultMsgType: 2,
mockValue: 15,
expectedValue: 15,
},
{
name: "valid type 3",
defaultMsgType: 3,
mockValue: 25,
expectedValue: 25,
},
{
name: "valid type 4",
defaultMsgType: 4,
mockValue: 35,
expectedValue: 35,
},
{
name: "invalid type 0",
defaultMsgType: 0,
mockValue: 0,
expectedValue: 0,
errorMsg: "invalid msg type: 0",
},
{
name: "invalid type 5",
defaultMsgType: 5,
mockValue: 0,
expectedValue: 0,
errorMsg: "invalid msg type: 5",
},
{
name: "invalid type max uint32",
defaultMsgType: 4294967295,
mockValue: 0,
expectedValue: 0,
errorMsg: "invalid msg type: 4294967295",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
// Create a new instance with mock PRNG for each test
testMagicHeaders := MagicHeaders{
Values: headers,
randomGenerator: &mockPRNG{returnValue: tt.mockValue},
}
result, err := testMagicHeaders.Get(tt.defaultMsgType)
if tt.errorMsg != "" {
require.Error(t, err)
require.Contains(t, err.Error(), tt.errorMsg)
require.Equal(t, uint32(0), result)
} else {
require.NoError(t, err)
require.Equal(t, tt.expectedValue, result)
}
})
}
}

50
device/awg/prng.go Normal file
View file

@ -0,0 +1,50 @@
package awg
import (
crand "crypto/rand"
v2 "math/rand/v2"
"golang.org/x/exp/constraints"
)
type RandomNumberGenerator[T constraints.Integer] interface {
RandomSizeInRange(min, max T) T
Get() uint64
ReadSize(size int) []byte
}
type PRNG[T constraints.Integer] struct {
cha8Rand *v2.ChaCha8
}
func NewPRNG[T constraints.Integer]() PRNG[T] {
buf := make([]byte, 32)
_, _ = crand.Read(buf)
return PRNG[T]{
cha8Rand: v2.NewChaCha8([32]byte(buf)),
}
}
func (p PRNG[T]) RandomSizeInRange(min, max T) T {
if min > max {
panic("min must be less than max")
}
if min == max {
return min
}
return T(p.Get()%uint64(max-min)) + min
}
func (p PRNG[T]) Get() uint64 {
return p.cha8Rand.Uint64()
}
func (p PRNG[T]) ReadSize(size int) []byte {
// TODO: use a memory pool to allocate
buf := make([]byte, size)
_, _ = p.cha8Rand.Read(buf)
return buf
}

View file

@ -55,7 +55,7 @@ func parseTag(input string) (Tag, error) {
return tag, nil return tag, nil
} }
func Parse(name, input string) (TagJunkPacketGenerator, error) { func ParseTagJunkGenerator(name, input string) (TagJunkPacketGenerator, error) {
inputSlice := strings.Split(input, "<") inputSlice := strings.Split(input, "<")
if len(inputSlice) <= 1 { if len(inputSlice) <= 1 {
return TagJunkPacketGenerator{}, fmt.Errorf("empty input: %s", input) return TagJunkPacketGenerator{}, fmt.Errorf("empty input: %s", input)

View file

@ -64,7 +64,7 @@ func TestParse(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
_, err := Parse(tt.args.name, tt.args.input) _, err := ParseTagJunkGenerator(tt.args.name, tt.args.input)
// TODO: ErrorAs doesn't work as you think // TODO: ErrorAs doesn't work as you think
if tt.wantErr != nil { if tt.wantErr != nil {

View file

@ -118,6 +118,7 @@ func (st *CookieChecker) CreateReply(
msg []byte, msg []byte,
recv uint32, recv uint32,
src []byte, src []byte,
msgType uint32,
) (*MessageCookieReply, error) { ) (*MessageCookieReply, error) {
st.RLock() st.RLock()
@ -153,7 +154,7 @@ func (st *CookieChecker) CreateReply(
smac1 := smac2 - blake2s.Size128 smac1 := smac2 - blake2s.Size128
reply := new(MessageCookieReply) reply := new(MessageCookieReply)
reply.Type = MessageCookieReplyType reply.Type = msgType
reply.Receiver = recv reply.Receiver = recv
_, err := rand.Read(reply.Nonce[:]) _, err := rand.Read(reply.Nonce[:])

View file

@ -99,7 +99,7 @@ func TestCookieMAC1(t *testing.T) {
0x8c, 0xe1, 0xe8, 0xfa, 0x67, 0x20, 0x80, 0x6d, 0x8c, 0xe1, 0xe8, 0xfa, 0x67, 0x20, 0x80, 0x6d,
} }
generator.AddMacs(msg) generator.AddMacs(msg)
reply, err := checker.CreateReply(msg, 1377, src) reply, err := checker.CreateReply(msg, 1377, src, DefaultMessageCookieReplyType)
if err != nil { if err != nil {
t.Fatal("Failed to create cookie reply:", err) t.Fatal("Failed to create cookie reply:", err)
} }

View file

@ -6,7 +6,9 @@
package device package device
import ( import (
"encoding/binary"
"errors" "errors"
"fmt"
"runtime" "runtime"
"sync" "sync"
"sync/atomic" "sync/atomic"
@ -578,6 +580,7 @@ func (device *Device) BindClose() error {
device.net.Unlock() device.net.Unlock()
return err return err
} }
func (device *Device) isAWG() bool { func (device *Device) isAWG() bool {
return device.version >= VersionAwg return device.version >= VersionAwg
} }
@ -591,171 +594,123 @@ func (device *Device) resetProtocol() {
} }
func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error { func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
if !tempAwg.ASecCfg.IsSet && !tempAwg.HandshakeHandler.IsSet { if !tempAwg.Cfg.IsSet && !tempAwg.HandshakeHandler.IsSet {
return nil return nil
} }
var errs []error var errs []error
isASecOn := false isAwgOn := false
device.awg.ASecMux.Lock() device.awg.Mux.Lock()
if tempAwg.ASecCfg.JunkPacketCount < 0 { if tempAwg.Cfg.JunkPacketCount < 0 {
errs = append(errs, ipcErrorf( errs = append(errs, ipcErrorf(
ipc.IpcErrorInvalid, ipc.IpcErrorInvalid,
"JunkPacketCount should be non negative", "JunkPacketCount should be non negative",
), ),
) )
} }
device.awg.ASecCfg.JunkPacketCount = tempAwg.ASecCfg.JunkPacketCount device.awg.Cfg.JunkPacketCount = tempAwg.Cfg.JunkPacketCount
if tempAwg.ASecCfg.JunkPacketCount != 0 { if tempAwg.Cfg.JunkPacketCount != 0 {
isASecOn = true isAwgOn = true
} }
device.awg.ASecCfg.JunkPacketMinSize = tempAwg.ASecCfg.JunkPacketMinSize device.awg.Cfg.JunkPacketMinSize = tempAwg.Cfg.JunkPacketMinSize
if tempAwg.ASecCfg.JunkPacketMinSize != 0 { if tempAwg.Cfg.JunkPacketMinSize != 0 {
isASecOn = true isAwgOn = true
} }
if device.awg.ASecCfg.JunkPacketCount > 0 && if device.awg.Cfg.JunkPacketCount > 0 &&
tempAwg.ASecCfg.JunkPacketMaxSize == tempAwg.ASecCfg.JunkPacketMinSize { tempAwg.Cfg.JunkPacketMaxSize == tempAwg.Cfg.JunkPacketMinSize {
tempAwg.ASecCfg.JunkPacketMaxSize++ // to make rand gen work tempAwg.Cfg.JunkPacketMaxSize++ // to make rand gen work
} }
if tempAwg.ASecCfg.JunkPacketMaxSize >= MaxSegmentSize { if tempAwg.Cfg.JunkPacketMaxSize >= MaxSegmentSize {
device.awg.ASecCfg.JunkPacketMinSize = 0 device.awg.Cfg.JunkPacketMinSize = 0
device.awg.ASecCfg.JunkPacketMaxSize = 1 device.awg.Cfg.JunkPacketMaxSize = 1
errs = append(errs, ipcErrorf( errs = append(errs, ipcErrorf(
ipc.IpcErrorInvalid, ipc.IpcErrorInvalid,
"JunkPacketMaxSize: %d; should be smaller than maxSegmentSize: %d", "JunkPacketMaxSize: %d; should be smaller than maxSegmentSize: %d",
tempAwg.ASecCfg.JunkPacketMaxSize, tempAwg.Cfg.JunkPacketMaxSize,
MaxSegmentSize, MaxSegmentSize,
)) ))
} else if tempAwg.ASecCfg.JunkPacketMaxSize < tempAwg.ASecCfg.JunkPacketMinSize { } else if tempAwg.Cfg.JunkPacketMaxSize < tempAwg.Cfg.JunkPacketMinSize {
errs = append(errs, ipcErrorf( errs = append(errs, ipcErrorf(
ipc.IpcErrorInvalid, ipc.IpcErrorInvalid,
"maxSize: %d; should be greater than minSize: %d", "maxSize: %d; should be greater than minSize: %d",
tempAwg.ASecCfg.JunkPacketMaxSize, tempAwg.Cfg.JunkPacketMaxSize,
tempAwg.ASecCfg.JunkPacketMinSize, tempAwg.Cfg.JunkPacketMinSize,
)) ))
} else { } else {
device.awg.ASecCfg.JunkPacketMaxSize = tempAwg.ASecCfg.JunkPacketMaxSize device.awg.Cfg.JunkPacketMaxSize = tempAwg.Cfg.JunkPacketMaxSize
} }
if tempAwg.ASecCfg.JunkPacketMaxSize != 0 { if tempAwg.Cfg.JunkPacketMaxSize != 0 {
isASecOn = true isAwgOn = true
} }
newInitSize := MessageInitiationSize + tempAwg.ASecCfg.InitHeaderJunkSize magicHeaders := make([]awg.MagicHeader, 4)
if newInitSize >= MaxSegmentSize { if len(tempAwg.Cfg.MagicHeaders.Values) != 4 {
errs = append(errs, ipcErrorf( return ipcErrorf(
ipc.IpcErrorInvalid, ipc.IpcErrorInvalid,
`init header size(148) + junkSize:%d; should be smaller than maxSegmentSize: %d`, "magic headers should have 4 values; got: %d",
tempAwg.ASecCfg.InitHeaderJunkSize, len(tempAwg.Cfg.MagicHeaders.Values),
MaxSegmentSize,
),
) )
} else {
device.awg.ASecCfg.InitHeaderJunkSize = tempAwg.ASecCfg.InitHeaderJunkSize
} }
if tempAwg.ASecCfg.InitHeaderJunkSize != 0 { if tempAwg.Cfg.MagicHeaders.Values[0].Min > 4 {
isASecOn = true isAwgOn = true
}
newResponseSize := MessageResponseSize + tempAwg.ASecCfg.ResponseHeaderJunkSize
if newResponseSize >= MaxSegmentSize {
errs = append(errs, ipcErrorf(
ipc.IpcErrorInvalid,
`response header size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d`,
tempAwg.ASecCfg.ResponseHeaderJunkSize,
MaxSegmentSize,
),
)
} else {
device.awg.ASecCfg.ResponseHeaderJunkSize = tempAwg.ASecCfg.ResponseHeaderJunkSize
}
if tempAwg.ASecCfg.ResponseHeaderJunkSize != 0 {
isASecOn = true
}
newCookieSize := MessageCookieReplySize + tempAwg.ASecCfg.CookieReplyHeaderJunkSize
if newCookieSize >= MaxSegmentSize {
errs = append(errs, ipcErrorf(
ipc.IpcErrorInvalid,
`cookie reply size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d`,
tempAwg.ASecCfg.CookieReplyHeaderJunkSize,
MaxSegmentSize,
),
)
} else {
device.awg.ASecCfg.CookieReplyHeaderJunkSize = tempAwg.ASecCfg.CookieReplyHeaderJunkSize
}
if tempAwg.ASecCfg.CookieReplyHeaderJunkSize != 0 {
isASecOn = true
}
newTransportSize := MessageTransportSize + tempAwg.ASecCfg.TransportHeaderJunkSize
if newTransportSize >= MaxSegmentSize {
errs = append(errs, ipcErrorf(
ipc.IpcErrorInvalid,
`transport size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d`,
tempAwg.ASecCfg.TransportHeaderJunkSize,
MaxSegmentSize,
),
)
} else {
device.awg.ASecCfg.TransportHeaderJunkSize = tempAwg.ASecCfg.TransportHeaderJunkSize
}
if tempAwg.ASecCfg.TransportHeaderJunkSize != 0 {
isASecOn = true
}
if tempAwg.ASecCfg.InitPacketMagicHeader > 4 {
isASecOn = true
device.log.Verbosef("UAPI: Updating init_packet_magic_header") device.log.Verbosef("UAPI: Updating init_packet_magic_header")
device.awg.ASecCfg.InitPacketMagicHeader = tempAwg.ASecCfg.InitPacketMagicHeader magicHeaders[0] = tempAwg.Cfg.MagicHeaders.Values[0]
MessageInitiationType = device.awg.ASecCfg.InitPacketMagicHeader
MessageInitiationType = magicHeaders[0].Min
} else { } else {
device.log.Verbosef("UAPI: Using default init type") device.log.Verbosef("UAPI: Using default init type")
MessageInitiationType = DefaultMessageInitiationType MessageInitiationType = DefaultMessageInitiationType
magicHeaders[0] = awg.NewMagicHeaderSameValue(DefaultMessageInitiationType)
} }
if tempAwg.ASecCfg.ResponsePacketMagicHeader > 4 { if tempAwg.Cfg.MagicHeaders.Values[1].Min > 4 {
isASecOn = true isAwgOn = true
device.log.Verbosef("UAPI: Updating response_packet_magic_header") device.log.Verbosef("UAPI: Updating response_packet_magic_header")
device.awg.ASecCfg.ResponsePacketMagicHeader = tempAwg.ASecCfg.ResponsePacketMagicHeader magicHeaders[1] = tempAwg.Cfg.MagicHeaders.Values[1]
MessageResponseType = device.awg.ASecCfg.ResponsePacketMagicHeader MessageResponseType = magicHeaders[1].Min
} else { } else {
device.log.Verbosef("UAPI: Using default response type") device.log.Verbosef("UAPI: Using default response type")
MessageResponseType = DefaultMessageResponseType MessageResponseType = DefaultMessageResponseType
magicHeaders[1] = awg.NewMagicHeaderSameValue(DefaultMessageResponseType)
} }
if tempAwg.ASecCfg.UnderloadPacketMagicHeader > 4 { if tempAwg.Cfg.MagicHeaders.Values[2].Min > 4 {
isASecOn = true isAwgOn = true
device.log.Verbosef("UAPI: Updating underload_packet_magic_header") device.log.Verbosef("UAPI: Updating underload_packet_magic_header")
device.awg.ASecCfg.UnderloadPacketMagicHeader = tempAwg.ASecCfg.UnderloadPacketMagicHeader magicHeaders[2] = tempAwg.Cfg.MagicHeaders.Values[2]
MessageCookieReplyType = device.awg.ASecCfg.UnderloadPacketMagicHeader MessageCookieReplyType = magicHeaders[2].Min
} else { } else {
device.log.Verbosef("UAPI: Using default underload type") device.log.Verbosef("UAPI: Using default underload type")
MessageCookieReplyType = DefaultMessageCookieReplyType MessageCookieReplyType = DefaultMessageCookieReplyType
magicHeaders[2] = awg.NewMagicHeaderSameValue(DefaultMessageCookieReplyType)
} }
if tempAwg.ASecCfg.TransportPacketMagicHeader > 4 { if tempAwg.Cfg.MagicHeaders.Values[3].Min > 4 {
isASecOn = true isAwgOn = true
device.log.Verbosef("UAPI: Updating transport_packet_magic_header") device.log.Verbosef("UAPI: Updating transport_packet_magic_header")
device.awg.ASecCfg.TransportPacketMagicHeader = tempAwg.ASecCfg.TransportPacketMagicHeader magicHeaders[3] = tempAwg.Cfg.MagicHeaders.Values[3]
MessageTransportType = device.awg.ASecCfg.TransportPacketMagicHeader MessageTransportType = magicHeaders[3].Min
} else { } else {
device.log.Verbosef("UAPI: Using default transport type") device.log.Verbosef("UAPI: Using default transport type")
MessageTransportType = DefaultMessageTransportType MessageTransportType = DefaultMessageTransportType
magicHeaders[3] = awg.NewMagicHeaderSameValue(DefaultMessageTransportType)
}
var err error
device.awg.Cfg.MagicHeaders, err = awg.NewMagicHeaders(magicHeaders)
if err != nil {
errs = append(errs, ipcErrorf(ipc.IpcErrorInvalid, "new magic headers: %w", err))
} }
isSameHeaderMap := map[uint32]struct{}{ isSameHeaderMap := map[uint32]struct{}{
@ -778,6 +733,78 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
) )
} }
newInitSize := MessageInitiationSize + tempAwg.Cfg.InitHeaderJunkSize
if newInitSize >= MaxSegmentSize {
errs = append(errs, ipcErrorf(
ipc.IpcErrorInvalid,
`init header size(148) + junkSize:%d; should be smaller than maxSegmentSize: %d`,
tempAwg.Cfg.InitHeaderJunkSize,
MaxSegmentSize,
),
)
} else {
device.awg.Cfg.InitHeaderJunkSize = tempAwg.Cfg.InitHeaderJunkSize
}
if tempAwg.Cfg.InitHeaderJunkSize != 0 {
isAwgOn = true
}
newResponseSize := MessageResponseSize + tempAwg.Cfg.ResponseHeaderJunkSize
if newResponseSize >= MaxSegmentSize {
errs = append(errs, ipcErrorf(
ipc.IpcErrorInvalid,
`response header size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d`,
tempAwg.Cfg.ResponseHeaderJunkSize,
MaxSegmentSize,
),
)
} else {
device.awg.Cfg.ResponseHeaderJunkSize = tempAwg.Cfg.ResponseHeaderJunkSize
}
if tempAwg.Cfg.ResponseHeaderJunkSize != 0 {
isAwgOn = true
}
newCookieSize := MessageCookieReplySize + tempAwg.Cfg.CookieReplyHeaderJunkSize
if newCookieSize >= MaxSegmentSize {
errs = append(errs, ipcErrorf(
ipc.IpcErrorInvalid,
`cookie reply size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d`,
tempAwg.Cfg.CookieReplyHeaderJunkSize,
MaxSegmentSize,
),
)
} else {
device.awg.Cfg.CookieReplyHeaderJunkSize = tempAwg.Cfg.CookieReplyHeaderJunkSize
}
if tempAwg.Cfg.CookieReplyHeaderJunkSize != 0 {
isAwgOn = true
}
newTransportSize := MessageTransportSize + tempAwg.Cfg.TransportHeaderJunkSize
if newTransportSize >= MaxSegmentSize {
errs = append(errs, ipcErrorf(
ipc.IpcErrorInvalid,
`transport size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d`,
tempAwg.Cfg.TransportHeaderJunkSize,
MaxSegmentSize,
),
)
} else {
device.awg.Cfg.TransportHeaderJunkSize = tempAwg.Cfg.TransportHeaderJunkSize
}
if tempAwg.Cfg.TransportHeaderJunkSize != 0 {
isAwgOn = true
}
isSameSizeMap := map[int]struct{}{ isSameSizeMap := map[int]struct{}{
newInitSize: {}, newInitSize: {},
newResponseSize: {}, newResponseSize: {},
@ -797,10 +824,10 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
) )
} else { } else {
msgTypeToJunkSize = map[uint32]int{ msgTypeToJunkSize = map[uint32]int{
MessageInitiationType: device.awg.ASecCfg.InitHeaderJunkSize, MessageInitiationType: device.awg.Cfg.InitHeaderJunkSize,
MessageResponseType: device.awg.ASecCfg.ResponseHeaderJunkSize, MessageResponseType: device.awg.Cfg.ResponseHeaderJunkSize,
MessageCookieReplyType: device.awg.ASecCfg.CookieReplyHeaderJunkSize, MessageCookieReplyType: device.awg.Cfg.CookieReplyHeaderJunkSize,
MessageTransportType: device.awg.ASecCfg.TransportHeaderJunkSize, MessageTransportType: device.awg.Cfg.TransportHeaderJunkSize,
} }
packetSizeToMsgType = map[int]uint32{ packetSizeToMsgType = map[int]uint32{
@ -811,12 +838,8 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
} }
} }
device.awg.IsASecOn.SetTo(isASecOn) device.awg.IsOn.SetTo(isAwgOn)
var err error device.awg.JunkCreator = awg.NewJunkCreator(device.awg.Cfg)
device.awg.JunkCreator, err = awg.NewJunkCreator(device.awg.ASecCfg)
if err != nil {
errs = append(errs, err)
}
if tempAwg.HandshakeHandler.IsSet { if tempAwg.HandshakeHandler.IsSet {
if err := tempAwg.HandshakeHandler.Validate(); err != nil { if err := tempAwg.HandshakeHandler.Validate(); err != nil {
@ -824,15 +847,92 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
ipc.IpcErrorInvalid, "handshake handler validate: %w", err)) ipc.IpcErrorInvalid, "handshake handler validate: %w", err))
} else { } else {
device.awg.HandshakeHandler = tempAwg.HandshakeHandler device.awg.HandshakeHandler = tempAwg.HandshakeHandler
device.awg.HandshakeHandler.ControlledJunk.DefaultJunkCount = tempAwg.ASecCfg.JunkPacketCount device.awg.HandshakeHandler.ControlledJunk.DefaultJunkCount = tempAwg.Cfg.JunkPacketCount
device.awg.HandshakeHandler.SpecialJunk.DefaultJunkCount = tempAwg.ASecCfg.JunkPacketCount device.awg.HandshakeHandler.SpecialJunk.DefaultJunkCount = tempAwg.Cfg.JunkPacketCount
device.version = VersionAwgSpecialHandshake device.version = VersionAwgSpecialHandshake
} }
} else { } else {
device.version = VersionAwg device.version = VersionAwg
} }
device.awg.ASecMux.Unlock() device.awg.Mux.Unlock()
return errors.Join(errs...) return errors.Join(errs...)
} }
func (device *Device) ProcessAWGPacket(size int, packet *[]byte, buffer *[MaxMessageSize]byte) (uint32, error) {
// TODO:
// if awg.WaitResponse.ShouldWait.IsSet() {
// awg.WaitResponse.Channel <- struct{}{}
// }
expectedMsgType, isKnownSize := packetSizeToMsgType[size]
if !isKnownSize {
msgType, err := device.handleTransport(size, packet, buffer)
if err != nil {
return 0, fmt.Errorf("handle transport: %w", err)
}
return msgType, nil
}
junkSize := msgTypeToJunkSize[expectedMsgType]
// transport size can align with other header types;
// making sure we have the right actualMsgType
actualMsgType, err := device.getMsgType(packet, junkSize)
if err != nil {
return 0, fmt.Errorf("get msg type: %w", err)
}
if actualMsgType == expectedMsgType {
*packet = (*packet)[junkSize:]
return actualMsgType, nil
}
device.log.Verbosef("awg: transport packet lined up with another msg type")
msgType, err := device.handleTransport(size, packet, buffer)
if err != nil {
return 0, fmt.Errorf("handle transport: %w", err)
}
return msgType, nil
}
func (device *Device) getMsgType(packet *[]byte, junkSize int) (uint32, error) {
msgTypeValue := binary.LittleEndian.Uint32((*packet)[junkSize : junkSize+4])
msgType, err := device.awg.GetMagicHeaderMinFor(msgTypeValue)
if err != nil {
return 0, fmt.Errorf("get magic header min: %w", err)
}
return msgType, nil
}
func (device *Device) handleTransport(size int, packet *[]byte, buffer *[MaxMessageSize]byte) (uint32, error) {
junkSize := device.awg.Cfg.TransportHeaderJunkSize
msgType, err := device.getMsgType(packet, junkSize)
if err != nil {
return 0, fmt.Errorf("get msg type: %w", err)
}
if msgType != MessageTransportType {
// probably a junk packet
return 0, fmt.Errorf("Received message with unknown type: %d", msgType)
}
if junkSize > 0 {
// remove junk from buffer by shifting the packet
// this buffer is also used for decryption, so it needs to be corrected
copy((*buffer)[:size], (*packet)[junkSize:])
size -= junkSize
// need to reinitialize packet as well
(*packet) = (*packet)[:size]
}
return msgType, nil
}

View file

@ -232,14 +232,14 @@ func TestAWGDevicePing(t *testing.T) {
"jc", "5", "jc", "5",
"jmin", "500", "jmin", "500",
"jmax", "1000", "jmax", "1000",
"s1", "30", "s1", "15",
"s2", "40", "s2", "18",
"s3", "50", "s3", "20",
"s4", "5", "s4", "25",
"h1", "123456", "h1", "123456-123500",
"h2", "67543", "h2", "67543-67550",
"h3", "123123", "h3", "123123-123200",
"h4", "32345", "h4", "32345-32350",
) )
t.Run("ping 1.0.0.1", func(t *testing.T) { t.Run("ping 1.0.0.1", func(t *testing.T) {
pair.Send(t, Ping, nil) pair.Send(t, Ping, nil)

View file

@ -205,12 +205,18 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
handshake.mixHash(handshake.remoteStatic[:]) handshake.mixHash(handshake.remoteStatic[:])
device.awg.ASecMux.RLock() device.awg.Mux.RLock()
msgType, err := device.awg.GetMsgType(DefaultMessageInitiationType)
if err != nil {
device.awg.Mux.RUnlock()
return nil, fmt.Errorf("get message type: %w", err)
}
msg := MessageInitiation{ msg := MessageInitiation{
Type: MessageInitiationType, Type: msgType,
Ephemeral: handshake.localEphemeral.publicKey(), Ephemeral: handshake.localEphemeral.publicKey(),
} }
device.awg.ASecMux.RUnlock() device.awg.Mux.RUnlock()
handshake.mixKey(msg.Ephemeral[:]) handshake.mixKey(msg.Ephemeral[:])
handshake.mixHash(msg.Ephemeral[:]) handshake.mixHash(msg.Ephemeral[:])
@ -264,12 +270,13 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
chainKey [blake2s.Size]byte chainKey [blake2s.Size]byte
) )
device.awg.ASecMux.RLock() device.awg.Mux.RLock()
if msg.Type != MessageInitiationType { if msg.Type != MessageInitiationType {
device.awg.ASecMux.RUnlock() device.awg.Mux.RUnlock()
return nil return nil
} }
device.awg.ASecMux.RUnlock() device.awg.Mux.RUnlock()
device.staticIdentity.RLock() device.staticIdentity.RLock()
defer device.staticIdentity.RUnlock() defer device.staticIdentity.RUnlock()
@ -384,9 +391,14 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
} }
var msg MessageResponse var msg MessageResponse
device.awg.ASecMux.RLock() device.awg.Mux.RLock()
msg.Type = MessageResponseType msg.Type, err = device.awg.GetMsgType(DefaultMessageResponseType)
device.awg.ASecMux.RUnlock() if err != nil {
device.awg.Mux.RUnlock()
return nil, fmt.Errorf("get message type: %w", err)
}
device.awg.Mux.RUnlock()
msg.Sender = handshake.localIndex msg.Sender = handshake.localIndex
msg.Receiver = handshake.remoteIndex msg.Receiver = handshake.remoteIndex
@ -436,12 +448,13 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
} }
func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
device.awg.ASecMux.RLock() device.awg.Mux.RLock()
if msg.Type != MessageResponseType { if msg.Type != MessageResponseType {
device.awg.ASecMux.RUnlock() device.awg.Mux.RUnlock()
return nil return nil
} }
device.awg.ASecMux.RUnlock() device.awg.Mux.RUnlock()
// lookup handshake by receiver // lookup handshake by receiver

View file

@ -129,7 +129,7 @@ func (device *Device) RoutineReceiveIncoming(
} }
deathSpiral = 0 deathSpiral = 0
device.awg.ASecMux.RLock() device.awg.Mux.RLock()
// handle each packet in the batch // handle each packet in the batch
for i, size := range sizes[:count] { for i, size := range sizes[:count] {
if size < MinMessageSize { if size < MinMessageSize {
@ -140,37 +140,11 @@ func (device *Device) RoutineReceiveIncoming(
packet := bufsArrs[i][:size] packet := bufsArrs[i][:size]
var msgType uint32 var msgType uint32
if device.isAWG() { if device.isAWG() {
// TODO: msgType, err = device.ProcessAWGPacket(size, &packet, bufsArrs[i])
// if awg.WaitResponse.ShouldWait.IsSet() {
// awg.WaitResponse.Channel <- struct{}{}
// }
if assumedMsgType, ok := packetSizeToMsgType[size]; ok { if err != nil {
junkSize := msgTypeToJunkSize[assumedMsgType] device.log.Verbosef("awg: process packet: %v", err)
// transport size can align with other header types; continue
// making sure we have the right msgType
msgType = binary.LittleEndian.Uint32(packet[junkSize : junkSize+4])
if msgType == assumedMsgType {
packet = packet[junkSize:]
} else {
device.log.Verbosef("transport packet lined up with another msg type")
msgType = binary.LittleEndian.Uint32(packet[:4])
}
} else {
transportJunkSize := device.awg.ASecCfg.TransportHeaderJunkSize
msgType = binary.LittleEndian.Uint32(packet[transportJunkSize : transportJunkSize+4])
if msgType != MessageTransportType {
// probably a junk packet
device.log.Verbosef("aSec: Received message with unknown type: %d", msgType)
continue
}
// remove junk from bufsArrs by shifting the packet
// this buffer is also used for decryption, so it needs to be corrected
copy(bufsArrs[i][:size], packet[transportJunkSize:])
size -= transportJunkSize
// need to reinitialize packet as well
packet = packet[:size]
} }
} else { } else {
msgType = binary.LittleEndian.Uint32(packet[:4]) msgType = binary.LittleEndian.Uint32(packet[:4])
@ -259,7 +233,7 @@ func (device *Device) RoutineReceiveIncoming(
default: default:
} }
} }
device.awg.ASecMux.RUnlock() device.awg.Mux.RUnlock()
for peer, elemsContainer := range elemsByPeer { for peer, elemsContainer := range elemsByPeer {
if peer.isRunning.Load() { if peer.isRunning.Load() {
peer.queue.inbound.c <- elemsContainer peer.queue.inbound.c <- elemsContainer
@ -318,7 +292,7 @@ func (device *Device) RoutineHandshake(id int) {
for elem := range device.queue.handshake.c { for elem := range device.queue.handshake.c {
device.awg.ASecMux.RLock() device.awg.Mux.RLock()
// handle cookie fields and ratelimiting // handle cookie fields and ratelimiting
@ -405,6 +379,9 @@ func (device *Device) RoutineHandshake(id int) {
goto skip goto skip
} }
// have to reassign msgType for ranged msgType to work
msg.Type = elem.msgType
// consume initiation // consume initiation
peer := device.ConsumeMessageInitiation(&msg) peer := device.ConsumeMessageInitiation(&msg)
if peer == nil { if peer == nil {
@ -437,6 +414,9 @@ func (device *Device) RoutineHandshake(id int) {
goto skip goto skip
} }
// have to reassign msgType for ranged msgType to work
msg.Type = elem.msgType
// consume response // consume response
peer := device.ConsumeMessageResponse(&msg) peer := device.ConsumeMessageResponse(&msg)
@ -470,7 +450,7 @@ func (device *Device) RoutineHandshake(id int) {
peer.SendKeepalive() peer.SendKeepalive()
} }
skip: skip:
device.awg.ASecMux.RUnlock() device.awg.Mux.RUnlock()
device.PutMessageBuffer(elem.buffer) device.PutMessageBuffer(elem.buffer)
} }
} }

View file

@ -130,7 +130,7 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
if peer.device.version >= VersionAwg { if peer.device.version >= VersionAwg {
var junks [][]byte var junks [][]byte
if peer.device.version == VersionAwgSpecialHandshake { if peer.device.version == VersionAwgSpecialHandshake {
peer.device.awg.ASecMux.RLock() peer.device.awg.Mux.RLock()
// set junks depending on packet type // set junks depending on packet type
junks = peer.device.awg.HandshakeHandler.GenerateSpecialJunk() junks = peer.device.awg.HandshakeHandler.GenerateSpecialJunk()
if junks == nil { if junks == nil {
@ -141,18 +141,13 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
} else { } else {
peer.device.log.Verbosef("%v - Special junks sent", peer) peer.device.log.Verbosef("%v - Special junks sent", peer)
} }
peer.device.awg.ASecMux.RUnlock() peer.device.awg.Mux.RUnlock()
} else { } else {
junks = make([][]byte, 0, peer.device.awg.ASecCfg.JunkPacketCount) junks = make([][]byte, 0, peer.device.awg.Cfg.JunkPacketCount)
}
peer.device.awg.ASecMux.RLock()
err := peer.device.awg.JunkCreator.CreateJunkPackets(&junks)
peer.device.awg.ASecMux.RUnlock()
if err != nil {
peer.device.log.Errorf("%v - %v", peer, err)
return err
} }
peer.device.awg.Mux.RLock()
peer.device.awg.JunkCreator.CreateJunkPackets(&junks)
peer.device.awg.Mux.RUnlock()
if len(junks) > 0 { if len(junks) > 0 {
err = peer.SendBuffers(junks) err = peer.SendBuffers(junks)
@ -242,10 +237,17 @@ func (device *Device) SendHandshakeCookie(
device.log.Verbosef("Sending cookie response for denied handshake message for %v", initiatingElem.endpoint.DstToString()) device.log.Verbosef("Sending cookie response for denied handshake message for %v", initiatingElem.endpoint.DstToString())
sender := binary.LittleEndian.Uint32(initiatingElem.packet[4:8]) sender := binary.LittleEndian.Uint32(initiatingElem.packet[4:8])
msgType, err := device.awg.GetMsgType(DefaultMessageCookieReplyType)
if err != nil {
device.log.Errorf("Get message type for cookie reply: %v", err)
return err
}
reply, err := device.cookieChecker.CreateReply( reply, err := device.cookieChecker.CreateReply(
initiatingElem.packet, initiatingElem.packet,
sender, sender,
initiatingElem.endpoint.DstToBytes(), initiatingElem.endpoint.DstToBytes(),
msgType,
) )
if err != nil { if err != nil {
device.log.Errorf("Failed to create cookie reply: %v", err) device.log.Errorf("Failed to create cookie reply: %v", err)
@ -528,7 +530,12 @@ func (device *Device) RoutineEncryption(id int) {
fieldReceiver := header[4:8] fieldReceiver := header[4:8]
fieldNonce := header[8:16] fieldNonce := header[8:16]
binary.LittleEndian.PutUint32(fieldType, MessageTransportType) msgType, err := device.awg.GetMsgType(DefaultMessageTransportType)
if err != nil {
device.log.Errorf("get message type for transport: %v", err)
continue
}
binary.LittleEndian.PutUint32(fieldType, msgType)
binary.LittleEndian.PutUint32(fieldReceiver, elem.keypair.remoteIndex) binary.LittleEndian.PutUint32(fieldReceiver, elem.keypair.remoteIndex)
binary.LittleEndian.PutUint64(fieldNonce, elem.nonce) binary.LittleEndian.PutUint64(fieldNonce, elem.nonce)

View file

@ -99,38 +99,36 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
} }
if device.isAWG() { if device.isAWG() {
if device.awg.ASecCfg.JunkPacketCount != 0 { if device.awg.Cfg.JunkPacketCount != 0 {
sendf("jc=%d", device.awg.ASecCfg.JunkPacketCount) sendf("jc=%d", device.awg.Cfg.JunkPacketCount)
} }
if device.awg.ASecCfg.JunkPacketMinSize != 0 { if device.awg.Cfg.JunkPacketMinSize != 0 {
sendf("jmin=%d", device.awg.ASecCfg.JunkPacketMinSize) sendf("jmin=%d", device.awg.Cfg.JunkPacketMinSize)
} }
if device.awg.ASecCfg.JunkPacketMaxSize != 0 { if device.awg.Cfg.JunkPacketMaxSize != 0 {
sendf("jmax=%d", device.awg.ASecCfg.JunkPacketMaxSize) sendf("jmax=%d", device.awg.Cfg.JunkPacketMaxSize)
} }
if device.awg.ASecCfg.InitHeaderJunkSize != 0 { if device.awg.Cfg.InitHeaderJunkSize != 0 {
sendf("s1=%d", device.awg.ASecCfg.InitHeaderJunkSize) sendf("s1=%d", device.awg.Cfg.InitHeaderJunkSize)
} }
if device.awg.ASecCfg.ResponseHeaderJunkSize != 0 { if device.awg.Cfg.ResponseHeaderJunkSize != 0 {
sendf("s2=%d", device.awg.ASecCfg.ResponseHeaderJunkSize) sendf("s2=%d", device.awg.Cfg.ResponseHeaderJunkSize)
} }
if device.awg.ASecCfg.CookieReplyHeaderJunkSize != 0 { if device.awg.Cfg.CookieReplyHeaderJunkSize != 0 {
sendf("s3=%d", device.awg.ASecCfg.CookieReplyHeaderJunkSize) sendf("s3=%d", device.awg.Cfg.CookieReplyHeaderJunkSize)
} }
if device.awg.ASecCfg.TransportHeaderJunkSize != 0 { if device.awg.Cfg.TransportHeaderJunkSize != 0 {
sendf("s4=%d", device.awg.ASecCfg.TransportHeaderJunkSize) sendf("s4=%d", device.awg.Cfg.TransportHeaderJunkSize)
} }
if device.awg.ASecCfg.InitPacketMagicHeader != 0 { for i, magicHeader := range device.awg.Cfg.MagicHeaders.Values {
sendf("h1=%d", device.awg.ASecCfg.InitPacketMagicHeader) if magicHeader.Min > 4 {
} if magicHeader.Min == magicHeader.Max {
if device.awg.ASecCfg.ResponsePacketMagicHeader != 0 { sendf("h%d=%d", i+1, magicHeader.Min)
sendf("h2=%d", device.awg.ASecCfg.ResponsePacketMagicHeader) continue
} }
if device.awg.ASecCfg.UnderloadPacketMagicHeader != 0 {
sendf("h3=%d", device.awg.ASecCfg.UnderloadPacketMagicHeader) sendf("h%d=%d-%d", i+1, magicHeader.Min, magicHeader.Max)
} }
if device.awg.ASecCfg.TransportPacketMagicHeader != 0 {
sendf("h4=%d", device.awg.ASecCfg.TransportPacketMagicHeader)
} }
specialJunkIpcFields := device.awg.HandshakeHandler.SpecialJunk.IpcGetFields() specialJunkIpcFields := device.awg.HandshakeHandler.SpecialJunk.IpcGetFields()
@ -200,6 +198,8 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
deviceConfig := true deviceConfig := true
tempAwg := awg.Protocol{} tempAwg := awg.Protocol{}
tempAwg.Cfg.MagicHeaders.Values = make([]awg.MagicHeader, 4)
scanner := bufio.NewScanner(r) scanner := bufio.NewScanner(r)
for scanner.Scan() { for scanner.Scan() {
line := scanner.Text() line := scanner.Text()
@ -312,8 +312,8 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol)
return ipcErrorf(ipc.IpcErrorInvalid, "parse junk_packet_count %w", err) return ipcErrorf(ipc.IpcErrorInvalid, "parse junk_packet_count %w", err)
} }
device.log.Verbosef("UAPI: Updating junk_packet_count") device.log.Verbosef("UAPI: Updating junk_packet_count")
tempAwg.ASecCfg.JunkPacketCount = junkPacketCount tempAwg.Cfg.JunkPacketCount = junkPacketCount
tempAwg.ASecCfg.IsSet = true tempAwg.Cfg.IsSet = true
case "jmin": case "jmin":
junkPacketMinSize, err := strconv.Atoi(value) junkPacketMinSize, err := strconv.Atoi(value)
@ -321,8 +321,8 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol)
return ipcErrorf(ipc.IpcErrorInvalid, "parse junk_packet_min_size %w", err) return ipcErrorf(ipc.IpcErrorInvalid, "parse junk_packet_min_size %w", err)
} }
device.log.Verbosef("UAPI: Updating junk_packet_min_size") device.log.Verbosef("UAPI: Updating junk_packet_min_size")
tempAwg.ASecCfg.JunkPacketMinSize = junkPacketMinSize tempAwg.Cfg.JunkPacketMinSize = junkPacketMinSize
tempAwg.ASecCfg.IsSet = true tempAwg.Cfg.IsSet = true
case "jmax": case "jmax":
junkPacketMaxSize, err := strconv.Atoi(value) junkPacketMaxSize, err := strconv.Atoi(value)
@ -330,8 +330,8 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol)
return ipcErrorf(ipc.IpcErrorInvalid, "parse junk_packet_max_size %w", err) return ipcErrorf(ipc.IpcErrorInvalid, "parse junk_packet_max_size %w", err)
} }
device.log.Verbosef("UAPI: Updating junk_packet_max_size") device.log.Verbosef("UAPI: Updating junk_packet_max_size")
tempAwg.ASecCfg.JunkPacketMaxSize = junkPacketMaxSize tempAwg.Cfg.JunkPacketMaxSize = junkPacketMaxSize
tempAwg.ASecCfg.IsSet = true tempAwg.Cfg.IsSet = true
case "s1": case "s1":
initPacketJunkSize, err := strconv.Atoi(value) initPacketJunkSize, err := strconv.Atoi(value)
@ -339,8 +339,8 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol)
return ipcErrorf(ipc.IpcErrorInvalid, "parse init_packet_junk_size %w", err) return ipcErrorf(ipc.IpcErrorInvalid, "parse init_packet_junk_size %w", err)
} }
device.log.Verbosef("UAPI: Updating init_packet_junk_size") device.log.Verbosef("UAPI: Updating init_packet_junk_size")
tempAwg.ASecCfg.InitHeaderJunkSize = initPacketJunkSize tempAwg.Cfg.InitHeaderJunkSize = initPacketJunkSize
tempAwg.ASecCfg.IsSet = true tempAwg.Cfg.IsSet = true
case "s2": case "s2":
responsePacketJunkSize, err := strconv.Atoi(value) responsePacketJunkSize, err := strconv.Atoi(value)
@ -348,8 +348,8 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol)
return ipcErrorf(ipc.IpcErrorInvalid, "parse response_packet_junk_size %w", err) return ipcErrorf(ipc.IpcErrorInvalid, "parse response_packet_junk_size %w", err)
} }
device.log.Verbosef("UAPI: Updating response_packet_junk_size") device.log.Verbosef("UAPI: Updating response_packet_junk_size")
tempAwg.ASecCfg.ResponseHeaderJunkSize = responsePacketJunkSize tempAwg.Cfg.ResponseHeaderJunkSize = responsePacketJunkSize
tempAwg.ASecCfg.IsSet = true tempAwg.Cfg.IsSet = true
case "s3": case "s3":
cookieReplyPacketJunkSize, err := strconv.Atoi(value) cookieReplyPacketJunkSize, err := strconv.Atoi(value)
@ -357,8 +357,8 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol)
return ipcErrorf(ipc.IpcErrorInvalid, "parse cookie_reply_packet_junk_size %w", err) return ipcErrorf(ipc.IpcErrorInvalid, "parse cookie_reply_packet_junk_size %w", err)
} }
device.log.Verbosef("UAPI: Updating cookie_reply_packet_junk_size") device.log.Verbosef("UAPI: Updating cookie_reply_packet_junk_size")
tempAwg.ASecCfg.CookieReplyHeaderJunkSize = cookieReplyPacketJunkSize tempAwg.Cfg.CookieReplyHeaderJunkSize = cookieReplyPacketJunkSize
tempAwg.ASecCfg.IsSet = true tempAwg.Cfg.IsSet = true
case "s4": case "s4":
transportPacketJunkSize, err := strconv.Atoi(value) transportPacketJunkSize, err := strconv.Atoi(value)
@ -366,47 +366,47 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol)
return ipcErrorf(ipc.IpcErrorInvalid, "parse transport_packet_junk_size %w", err) return ipcErrorf(ipc.IpcErrorInvalid, "parse transport_packet_junk_size %w", err)
} }
device.log.Verbosef("UAPI: Updating transport_packet_junk_size") device.log.Verbosef("UAPI: Updating transport_packet_junk_size")
tempAwg.ASecCfg.TransportHeaderJunkSize = transportPacketJunkSize tempAwg.Cfg.TransportHeaderJunkSize = transportPacketJunkSize
tempAwg.ASecCfg.IsSet = true tempAwg.Cfg.IsSet = true
case "h1": case "h1":
initPacketMagicHeader, err := strconv.ParseUint(value, 10, 32) initMagicHeader, err := awg.ParseMagicHeader(key, value)
if err != nil { if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "parse init_packet_magic_header %w", err) return ipcErrorf(ipc.IpcErrorInvalid, "uapi: %w", err)
} }
tempAwg.ASecCfg.InitPacketMagicHeader = uint32(initPacketMagicHeader)
tempAwg.ASecCfg.IsSet = true
tempAwg.Cfg.MagicHeaders.Values[0] = initMagicHeader
tempAwg.Cfg.IsSet = true
case "h2": case "h2":
responsePacketMagicHeader, err := strconv.ParseUint(value, 10, 32) responseMagicHeader, err := awg.ParseMagicHeader(key, value)
if err != nil { if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "parse response_packet_magic_header %w", err) return ipcErrorf(ipc.IpcErrorInvalid, "uapi: %w", err)
} }
tempAwg.ASecCfg.ResponsePacketMagicHeader = uint32(responsePacketMagicHeader)
tempAwg.ASecCfg.IsSet = true
tempAwg.Cfg.MagicHeaders.Values[1] = responseMagicHeader
tempAwg.Cfg.IsSet = true
case "h3": case "h3":
underloadPacketMagicHeader, err := strconv.ParseUint(value, 10, 32) cookieReplyMagicHeader, err := awg.ParseMagicHeader(key, value)
if err != nil { if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "parse underload_packet_magic_header %w", err) return ipcErrorf(ipc.IpcErrorInvalid, "uapi: %w", err)
} }
tempAwg.ASecCfg.UnderloadPacketMagicHeader = uint32(underloadPacketMagicHeader)
tempAwg.ASecCfg.IsSet = true
tempAwg.Cfg.MagicHeaders.Values[2] = cookieReplyMagicHeader
tempAwg.Cfg.IsSet = true
case "h4": case "h4":
transportPacketMagicHeader, err := strconv.ParseUint(value, 10, 32) transportMagicHeader, err := awg.ParseMagicHeader(key, value)
if err != nil { if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "parse transport_packet_magic_header %w", err) return ipcErrorf(ipc.IpcErrorInvalid, "uapi: %w", err)
} }
tempAwg.ASecCfg.TransportPacketMagicHeader = uint32(transportPacketMagicHeader)
tempAwg.ASecCfg.IsSet = true tempAwg.Cfg.MagicHeaders.Values[3] = transportMagicHeader
tempAwg.Cfg.IsSet = true
case "i1", "i2", "i3", "i4", "i5": case "i1", "i2", "i3", "i4", "i5":
if len(value) == 0 { if len(value) == 0 {
device.log.Verbosef("UAPI: received empty %s", key) device.log.Verbosef("UAPI: received empty %s", key)
return nil return nil
} }
generators, err := awg.Parse(key, value) generators, err := awg.ParseTagJunkGenerator(key, value)
if err != nil { if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "invalid %s: %w", key, err) return ipcErrorf(ipc.IpcErrorInvalid, "invalid %s: %w", key, err)
} }
@ -419,7 +419,7 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol)
return nil return nil
} }
generators, err := awg.Parse(key, value) generators, err := awg.ParseTagJunkGenerator(key, value)
if err != nil { if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "invalid %s: %w", key, err) return ipcErrorf(ipc.IpcErrorInvalid, "invalid %s: %w", key, err)
} }

2
go.mod
View file

@ -5,9 +5,9 @@ go 1.24.4
require ( require (
github.com/stretchr/testify v1.10.0 github.com/stretchr/testify v1.10.0
github.com/tevino/abool v1.2.0 github.com/tevino/abool v1.2.0
github.com/tevino/abool/v2 v2.1.0
go.uber.org/atomic v1.11.0 go.uber.org/atomic v1.11.0
golang.org/x/crypto v0.39.0 golang.org/x/crypto v0.39.0
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090
golang.org/x/net v0.41.0 golang.org/x/net v0.41.0
golang.org/x/sys v0.33.0 golang.org/x/sys v0.33.0
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2

14
go.sum
View file

@ -2,24 +2,18 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg= github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg=
github.com/google/btree v1.1.3/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4= github.com/google/btree v1.1.3/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/tevino/abool v1.2.0 h1:heAkClL8H6w+mK5md9dzsuohKeXHUpY7Vw0ZCKW+huA= github.com/tevino/abool v1.2.0 h1:heAkClL8H6w+mK5md9dzsuohKeXHUpY7Vw0ZCKW+huA=
github.com/tevino/abool v1.2.0/go.mod h1:qc66Pna1RiIsPa7O4Egxxs9OqkuxDX55zznh9K07Tzg= github.com/tevino/abool v1.2.0/go.mod h1:qc66Pna1RiIsPa7O4Egxxs9OqkuxDX55zznh9K07Tzg=
github.com/tevino/abool/v2 v2.1.0 h1:7w+Vf9f/5gmKT4m4qkayb33/92M+Um45F2BkHOR+L/c=
github.com/tevino/abool/v2 v2.1.0/go.mod h1:+Lmlqk6bHDWHqN1cbxqhwEAwMPXgc8I1SDEamtseuXY=
go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE=
go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0=
golang.org/x/crypto v0.39.0 h1:SHs+kF4LP+f+p14esP5jAoDpHU8Gu/v9lFRK6IT5imM= golang.org/x/crypto v0.39.0 h1:SHs+kF4LP+f+p14esP5jAoDpHU8Gu/v9lFRK6IT5imM=
golang.org/x/crypto v0.39.0/go.mod h1:L+Xg3Wf6HoL4Bn4238Z6ft6KfEpN0tJGo53AAPC632U= golang.org/x/crypto v0.39.0/go.mod h1:L+Xg3Wf6HoL4Bn4238Z6ft6KfEpN0tJGo53AAPC632U=
golang.org/x/mod v0.13.0 h1:I/DsJXRlw/8l/0c24sM9yb0T4z9liZTduXvdAWYiysY= golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY=
golang.org/x/mod v0.21.0 h1:vvrHzRwRfVKSiLrG+d4FMl/Qi4ukBCE6kZlTUkDYRT0= golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc=
golang.org/x/mod v0.21.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY=
golang.org/x/net v0.41.0 h1:vBTly1HeNPEn3wtREYfy4GZ/NECgw2Cnl+nK6Nz3uvw= golang.org/x/net v0.41.0 h1:vBTly1HeNPEn3wtREYfy4GZ/NECgw2Cnl+nK6Nz3uvw=
golang.org/x/net v0.41.0/go.mod h1:B/K4NNqkfmg07DQYrbwvSluqCJOOXwUjeb/5lOisjbA= golang.org/x/net v0.41.0/go.mod h1:B/K4NNqkfmg07DQYrbwvSluqCJOOXwUjeb/5lOisjbA=
golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=
@ -34,7 +28,3 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gvisor.dev/gvisor v0.0.0-20231202080848-1f7806d17489 h1:ze1vwAdliUAr68RQ5NtufWaXaOg8WUO2OACzEV+TNdE= gvisor.dev/gvisor v0.0.0-20231202080848-1f7806d17489 h1:ze1vwAdliUAr68RQ5NtufWaXaOg8WUO2OACzEV+TNdE=
gvisor.dev/gvisor v0.0.0-20231202080848-1f7806d17489/go.mod h1:10sU+Uh5KKNv1+2x2A0Gvzt8FjD3ASIhorV3YsauXhk= gvisor.dev/gvisor v0.0.0-20231202080848-1f7806d17489/go.mod h1:10sU+Uh5KKNv1+2x2A0Gvzt8FjD3ASIhorV3YsauXhk=
gvisor.dev/gvisor v0.0.0-20250428193742-2d800c3129d5 h1:sfK5nHuG7lRFZ2FdTT3RimOqWBg8IrVm+/Vko1FVOsk=
gvisor.dev/gvisor v0.0.0-20250428193742-2d800c3129d5/go.mod h1:3r5CMtNQMKIvBlrmM9xWUNamjKBYPOWyXOjmg5Kts3g=
gvisor.dev/gvisor v0.0.0-20250606233247-e3c4c4cad86f h1:zmc4cHEcCudRt2O8VsCW7nYLfAsbVY2i910/DAop1TM=
gvisor.dev/gvisor v0.0.0-20250606233247-e3c4c4cad86f/go.mod h1:3r5CMtNQMKIvBlrmM9xWUNamjKBYPOWyXOjmg5Kts3g=