From be20e770772a595953065b270fc7eb973cdbeaf9 Mon Sep 17 00:00:00 2001
From: Mark Puha
Date: Thu, 10 Jul 2025 05:56:35 +0200
Subject: [PATCH] chore: magic header tests
---
device/awg/awg.go | 6 +-
device/awg/magic_header.go | 24 +-
device/awg/magic_header_test.go | 488 ++++++++++++++++++++++++++++++++
device/awg/{util.go => prng.go} | 6 +
device/device.go | 2 +-
5 files changed, 511 insertions(+), 15 deletions(-)
create mode 100644 device/awg/magic_header_test.go
rename device/awg/{util.go => prng.go} (84%)
diff --git a/device/awg/awg.go b/device/awg/awg.go
index 6138c5f..336617e 100644
--- a/device/awg/awg.go
+++ b/device/awg/awg.go
@@ -82,14 +82,14 @@ func (protocol *Protocol) createHeaderJunk(junkSize int, extraSize int) ([]byte,
return junk, nil
}
-func (protocol *Protocol) GetLimitMin(msgTypeRange uint32) (uint32, error) {
- for _, limit := range protocol.MagicHeaders.headers {
+func (protocol *Protocol) GetMagicHeaderMinFor(msgTypeRange uint32) (uint32, error) {
+ for _, limit := range protocol.MagicHeaders.headerValues {
if limit.Min <= msgTypeRange && msgTypeRange <= limit.Max {
return limit.Min, nil
}
}
- return 0, fmt.Errorf("no limit for range: %d", msgTypeRange)
+ return 0, fmt.Errorf("no header for range: %d", msgTypeRange)
}
func (protocol *Protocol) GetMsgType(defaultMsgType uint32) (uint32, error) {
diff --git a/device/awg/magic_header.go b/device/awg/magic_header.go
index f7ad45c..df8bdf1 100644
--- a/device/awg/magic_header.go
+++ b/device/awg/magic_header.go
@@ -35,6 +35,8 @@ func ParseMagicHeader(key, value string) (MagicHeader, error) {
}
return NewMagicHeader(uint32(magicHeader), uint32(magicHeader))
+ } else if len(splitLimits[0]) == 0 || len(splitLimits[1]) == 0 {
+ return MagicHeader{}, fmt.Errorf("invalid value for key: %s; value: %s; expected format: min-max", key, value)
}
min, err := strconv.ParseUint(splitLimits[0], 10, 32)
@@ -56,30 +58,30 @@ func ParseMagicHeader(key, value string) (MagicHeader, error) {
}
type MagicHeaders struct {
- headers []MagicHeader
- randomGenerator PRNG[uint32]
+ headerValues []MagicHeader
+ randomGenerator RandomNumberGenerator[uint32]
}
-func NewMagicHeaders(magicHeaders []MagicHeader) (MagicHeaders, error) {
- if len(magicHeaders) != 4 {
- return MagicHeaders{}, fmt.Errorf("all header types should be included: %v", magicHeaders)
+func NewMagicHeaders(headerValues []MagicHeader) (MagicHeaders, error) {
+ if len(headerValues) != 4 {
+ return MagicHeaders{}, fmt.Errorf("all header types should be included: %v", headerValues)
}
- sortedMagicHeaders := slices.SortedFunc(slices.Values(magicHeaders), func(lhs MagicHeader, rhs MagicHeader) int {
+ sortedMagicHeaders := slices.SortedFunc(slices.Values(headerValues), func(lhs MagicHeader, rhs MagicHeader) int {
return cmp.Compare(lhs.Min, rhs.Min)
})
for i := range 3 {
- if sortedMagicHeaders[i].Min > sortedMagicHeaders[i+1].Min {
+ if sortedMagicHeaders[i].Max >= sortedMagicHeaders[i+1].Min {
return MagicHeaders{}, fmt.Errorf(
"magic headers shouldn't overlap; %v > %v",
- sortedMagicHeaders[i-1].Min,
- sortedMagicHeaders[i].Min,
+ sortedMagicHeaders[i].Max,
+ sortedMagicHeaders[i+1].Min,
)
}
}
- return MagicHeaders{headers: magicHeaders, randomGenerator: NewPRNG[uint32]()}, nil
+ return MagicHeaders{headerValues: headerValues, randomGenerator: NewPRNG[uint32]()}, nil
}
func (mh *MagicHeaders) Get(defaultMsgType uint32) (uint32, error) {
@@ -87,5 +89,5 @@ func (mh *MagicHeaders) Get(defaultMsgType uint32) (uint32, error) {
return 0, fmt.Errorf("invalid msg type: %d", defaultMsgType)
}
- return mh.randomGenerator.RandomSizeInRange(mh.headers[defaultMsgType-1].Min, mh.headers[defaultMsgType-1].Max), nil
+ return mh.randomGenerator.RandomSizeInRange(mh.headerValues[defaultMsgType-1].Min, mh.headerValues[defaultMsgType-1].Max), nil
}
diff --git a/device/awg/magic_header_test.go b/device/awg/magic_header_test.go
new file mode 100644
index 0000000..d9e608c
--- /dev/null
+++ b/device/awg/magic_header_test.go
@@ -0,0 +1,488 @@
+package awg
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestNewMagicHeaderSameValue(t *testing.T) {
+ tests := []struct {
+ name string
+ value uint32
+ expected MagicHeader
+ }{
+ {
+ name: "zero value",
+ value: 0,
+ expected: MagicHeader{Min: 0, Max: 0},
+ },
+ {
+ name: "small value",
+ value: 1,
+ expected: MagicHeader{Min: 1, Max: 1},
+ },
+ {
+ name: "large value",
+ value: 4294967295, // max uint32
+ expected: MagicHeader{Min: 4294967295, Max: 4294967295},
+ },
+ {
+ name: "medium value",
+ value: 1000,
+ expected: MagicHeader{Min: 1000, Max: 1000},
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ result := NewMagicHeaderSameValue(tt.value)
+ require.Equal(t, tt.expected, result)
+ })
+ }
+}
+
+func TestNewMagicHeader(t *testing.T) {
+ tests := []struct {
+ name string
+ min uint32
+ max uint32
+ expected MagicHeader
+ errorMsg string
+ }{
+ {
+ name: "valid range",
+ min: 1,
+ max: 10,
+ expected: MagicHeader{Min: 1, Max: 10},
+ },
+ {
+ name: "equal values",
+ min: 5,
+ max: 5,
+ expected: MagicHeader{Min: 5, Max: 5},
+ },
+ {
+ name: "zero range",
+ min: 0,
+ max: 0,
+ expected: MagicHeader{Min: 0, Max: 0},
+ },
+ {
+ name: "max uint32 range",
+ min: 4294967294,
+ max: 4294967295,
+ expected: MagicHeader{Min: 4294967294, Max: 4294967295},
+ },
+ {
+ name: "min greater than max",
+ min: 10,
+ max: 5,
+ expected: MagicHeader{},
+ errorMsg: "min (10) cannot be greater than max (5)",
+ },
+ {
+ name: "large min greater than max",
+ min: 4294967295,
+ max: 1,
+ expected: MagicHeader{},
+ errorMsg: "min (4294967295) cannot be greater than max (1)",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ result, err := NewMagicHeader(tt.min, tt.max)
+
+ if tt.errorMsg != "" {
+ require.Error(t, err)
+ require.Contains(t, err.Error(), tt.errorMsg)
+ require.Equal(t, MagicHeader{}, result)
+ } else {
+ require.NoError(t, err)
+ require.Equal(t, tt.expected, result)
+ }
+ })
+ }
+}
+
+func TestParseMagicHeader(t *testing.T) {
+ tests := []struct {
+ name string
+ key string
+ value string
+ expected MagicHeader
+ errorMsg string
+ }{
+ {
+ name: "single value",
+ key: "header1",
+ value: "100",
+ expected: MagicHeader{Min: 100, Max: 100},
+ },
+ {
+ name: "valid range",
+ key: "header2",
+ value: "10-20",
+ expected: MagicHeader{Min: 10, Max: 20},
+ },
+ {
+ name: "zero single value",
+ key: "header3",
+ value: "0",
+ expected: MagicHeader{Min: 0, Max: 0},
+ },
+ {
+ name: "zero range",
+ key: "header4",
+ value: "0-0",
+ expected: MagicHeader{Min: 0, Max: 0},
+ },
+ {
+ name: "max uint32 single",
+ key: "header5",
+ value: "4294967295",
+ expected: MagicHeader{Min: 4294967295, Max: 4294967295},
+ },
+ {
+ name: "max uint32 range",
+ key: "header6",
+ value: "4294967294-4294967295",
+ expected: MagicHeader{Min: 4294967294, Max: 4294967295},
+ },
+ {
+ name: "invalid single value - not number",
+ key: "header7",
+ value: "abc",
+ expected: MagicHeader{},
+ errorMsg: "parse key: header7; value: abc;",
+ },
+ {
+ name: "invalid single value - negative",
+ key: "header8",
+ value: "-5",
+ expected: MagicHeader{},
+ errorMsg: "invalid value for key: header8; value: -5;",
+ },
+ {
+ name: "invalid single value - too large",
+ key: "header9",
+ value: "4294967296",
+ expected: MagicHeader{},
+ errorMsg: "parse key: header9; value: 4294967296;",
+ },
+ {
+ name: "invalid range - min not number",
+ key: "header10",
+ value: "abc-10",
+ expected: MagicHeader{},
+ errorMsg: "parse min key: header10; value: abc;",
+ },
+ {
+ name: "invalid range - max not number",
+ key: "header11",
+ value: "10-abc",
+ expected: MagicHeader{},
+ errorMsg: "parse max key: header11; value: abc;",
+ },
+ {
+ name: "invalid range - min greater than max",
+ key: "header12",
+ value: "20-10",
+ expected: MagicHeader{},
+ errorMsg: "new magicHeader key: header12; value: 20-10;",
+ },
+ {
+ name: "invalid range - too many parts",
+ key: "header13",
+ value: "10-20-30",
+ expected: MagicHeader{},
+ errorMsg: "parse key: header13; value: 10-20-30;",
+ },
+ {
+ name: "empty value",
+ key: "header14",
+ value: "",
+ expected: MagicHeader{},
+ errorMsg: "parse key: header14; value: ;",
+ },
+ {
+ name: "hyphen only",
+ key: "header15",
+ value: "-",
+ expected: MagicHeader{},
+ errorMsg: "invalid value for key: header15; value: -;",
+ },
+ {
+ name: "empty min",
+ key: "header16",
+ value: "-10",
+ expected: MagicHeader{},
+ errorMsg: "invalid value for key: header16; value: -10;",
+ },
+ {
+ name: "empty max",
+ key: "header17",
+ value: "10-",
+ expected: MagicHeader{},
+ errorMsg: "invalid value for key: header17; value: 10-;",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ result, err := ParseMagicHeader(tt.key, tt.value)
+
+ if tt.errorMsg != "" {
+ require.Error(t, err)
+ require.Contains(t, err.Error(), tt.errorMsg)
+ require.Equal(t, MagicHeader{}, result)
+ } else {
+ require.NoError(t, err)
+ require.Equal(t, tt.expected, result)
+ }
+ })
+ }
+}
+
+func TestNewMagicHeaders(t *testing.T) {
+ tests := []struct {
+ name string
+ magicHeaders []MagicHeader
+ errorMsg string
+ }{
+ {
+ name: "valid non-overlapping headers",
+ magicHeaders: []MagicHeader{
+ {Min: 1, Max: 10},
+ {Min: 11, Max: 20},
+ {Min: 21, Max: 30},
+ {Min: 31, Max: 40},
+ },
+ },
+ {
+ name: "valid adjacent headers",
+ magicHeaders: []MagicHeader{
+ {Min: 1, Max: 1},
+ {Min: 2, Max: 2},
+ {Min: 3, Max: 3},
+ {Min: 4, Max: 4},
+ },
+ },
+ {
+ name: "valid zero-based headers",
+ magicHeaders: []MagicHeader{
+ {Min: 0, Max: 0},
+ {Min: 1, Max: 1},
+ {Min: 2, Max: 2},
+ {Min: 3, Max: 3},
+ },
+ },
+ {
+ name: "valid large value headers",
+ magicHeaders: []MagicHeader{
+ {Min: 4294967290, Max: 4294967291},
+ {Min: 4294967292, Max: 4294967293},
+ {Min: 4294967294, Max: 4294967294},
+ {Min: 4294967295, Max: 4294967295},
+ },
+ },
+ {
+ name: "too few headers",
+ magicHeaders: []MagicHeader{
+ {Min: 1, Max: 10},
+ {Min: 11, Max: 20},
+ {Min: 21, Max: 30},
+ },
+ errorMsg: "all header types should be included:",
+ },
+ {
+ name: "too many headers",
+ magicHeaders: []MagicHeader{
+ {Min: 1, Max: 10},
+ {Min: 11, Max: 20},
+ {Min: 21, Max: 30},
+ {Min: 31, Max: 40},
+ {Min: 41, Max: 50},
+ },
+ errorMsg: "all header types should be included:",
+ },
+ {
+ name: "empty headers",
+ magicHeaders: []MagicHeader{},
+ errorMsg: "all header types should be included:",
+ },
+ {
+ name: "overlapping headers",
+ magicHeaders: []MagicHeader{
+ {Min: 1, Max: 15},
+ {Min: 10, Max: 20},
+ {Min: 25, Max: 30},
+ {Min: 35, Max: 40},
+ },
+ errorMsg: "magic headers shouldn't overlap;",
+ },
+ {
+ name: "overlapping headers at limit-first",
+ magicHeaders: []MagicHeader{
+ {Min: 1, Max: 10},
+ {Min: 10, Max: 20},
+ {Min: 25, Max: 30},
+ {Min: 35, Max: 40},
+ },
+ errorMsg: "magic headers shouldn't overlap;",
+ },
+ {
+ name: "overlapping headers at limit-second",
+ magicHeaders: []MagicHeader{
+ {Min: 1, Max: 10},
+ {Min: 15, Max: 25},
+ {Min: 25, Max: 30},
+ {Min: 35, Max: 40},
+ },
+ errorMsg: "magic headers shouldn't overlap;",
+ },
+ {
+ name: "overlapping headers at limit-third",
+ magicHeaders: []MagicHeader{
+ {Min: 1, Max: 10},
+ {Min: 15, Max: 25},
+ {Min: 30, Max: 35},
+ {Min: 35, Max: 40},
+ },
+ errorMsg: "magic headers shouldn't overlap;",
+ },
+ {
+ name: "identical ranges",
+ magicHeaders: []MagicHeader{
+ {Min: 10, Max: 20},
+ {Min: 10, Max: 20},
+ {Min: 25, Max: 30},
+ {Min: 35, Max: 40},
+ },
+ errorMsg: "magic headers shouldn't overlap;",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ result, err := NewMagicHeaders(tt.magicHeaders)
+
+ if tt.errorMsg != "" {
+ require.Error(t, err)
+ require.Contains(t, err.Error(), tt.errorMsg)
+ require.Equal(t, MagicHeaders{}, result)
+ } else {
+ require.NoError(t, err)
+ require.Equal(t, tt.magicHeaders, result.headerValues)
+ require.NotNil(t, result.randomGenerator)
+ }
+ })
+ }
+}
+
+// Mock PRNG for testing
+type mockPRNG struct {
+ returnValue uint32
+}
+
+func (m *mockPRNG) RandomSizeInRange(min, max uint32) uint32 {
+ return m.returnValue
+}
+
+func (m *mockPRNG) Get() uint64 {
+ return 0
+}
+func (m *mockPRNG) ReadSize(size int) []byte {
+ return make([]byte, size)
+}
+
+func TestMagicHeaders_Get(t *testing.T) {
+ // Create test headers
+ headers := []MagicHeader{
+ {Min: 1, Max: 10},
+ {Min: 11, Max: 20},
+ {Min: 21, Max: 30},
+ {Min: 31, Max: 40},
+ }
+
+ tests := []struct {
+ name string
+ defaultMsgType uint32
+ mockValue uint32
+ expectedValue uint32
+ errorMsg string
+ }{
+ {
+ name: "valid type 1",
+ defaultMsgType: 1,
+ mockValue: 5,
+ expectedValue: 5,
+ },
+ {
+ name: "valid type 2",
+ defaultMsgType: 2,
+ mockValue: 15,
+ expectedValue: 15,
+ },
+ {
+ name: "valid type 3",
+ defaultMsgType: 3,
+ mockValue: 25,
+ expectedValue: 25,
+ },
+ {
+ name: "valid type 4",
+ defaultMsgType: 4,
+ mockValue: 35,
+ expectedValue: 35,
+ },
+ {
+ name: "invalid type 0",
+ defaultMsgType: 0,
+ mockValue: 0,
+ expectedValue: 0,
+ errorMsg: "invalid msg type: 0",
+ },
+ {
+ name: "invalid type 5",
+ defaultMsgType: 5,
+ mockValue: 0,
+ expectedValue: 0,
+ errorMsg: "invalid msg type: 5",
+ },
+ {
+ name: "invalid type max uint32",
+ defaultMsgType: 4294967295,
+ mockValue: 0,
+ expectedValue: 0,
+ errorMsg: "invalid msg type: 4294967295",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ // Create a new instance with mock PRNG for each test
+ testMagicHeaders := MagicHeaders{
+ headerValues: headers,
+ randomGenerator: &mockPRNG{returnValue: tt.mockValue},
+ }
+
+ result, err := testMagicHeaders.Get(tt.defaultMsgType)
+
+ if tt.errorMsg != "" {
+ require.Error(t, err)
+ require.Contains(t, err.Error(), tt.errorMsg)
+ require.Equal(t, uint32(0), result)
+ } else {
+ require.NoError(t, err)
+ require.Equal(t, tt.expectedValue, result)
+ }
+ })
+ }
+}
diff --git a/device/awg/util.go b/device/awg/prng.go
similarity index 84%
rename from device/awg/util.go
rename to device/awg/prng.go
index 164e2b2..e7661d7 100644
--- a/device/awg/util.go
+++ b/device/awg/prng.go
@@ -7,6 +7,12 @@ import (
"golang.org/x/exp/constraints"
)
+type RandomNumberGenerator[T constraints.Integer] interface {
+ RandomSizeInRange(min, max T) T
+ Get() uint64
+ ReadSize(size int) []byte
+}
+
type PRNG[T constraints.Integer] struct {
cha8Rand *v2.ChaCha8
}
diff --git a/device/device.go b/device/device.go
index 1f1c035..1a2b726 100644
--- a/device/device.go
+++ b/device/device.go
@@ -883,7 +883,7 @@ func (device *Device) Logic(size int, packet *[]byte, bufsArrs *[MaxMessageSize]
func (device *Device) getMsgType(packet *[]byte, junkSize int) (uint32, error) {
msgTypeRange := binary.LittleEndian.Uint32((*packet)[junkSize : junkSize+4])
- msgType, err := device.awg.GetLimitMin(msgTypeRange)
+ msgType, err := device.awg.GetMagicHeaderMinFor(msgTypeRange)
if err != nil {
return 0, fmt.Errorf("get limit min: %w", err)