diff --git a/device/awg/awg.go b/device/awg/awg.go index 6138c5f..336617e 100644 --- a/device/awg/awg.go +++ b/device/awg/awg.go @@ -82,14 +82,14 @@ func (protocol *Protocol) createHeaderJunk(junkSize int, extraSize int) ([]byte, return junk, nil } -func (protocol *Protocol) GetLimitMin(msgTypeRange uint32) (uint32, error) { - for _, limit := range protocol.MagicHeaders.headers { +func (protocol *Protocol) GetMagicHeaderMinFor(msgTypeRange uint32) (uint32, error) { + for _, limit := range protocol.MagicHeaders.headerValues { if limit.Min <= msgTypeRange && msgTypeRange <= limit.Max { return limit.Min, nil } } - return 0, fmt.Errorf("no limit for range: %d", msgTypeRange) + return 0, fmt.Errorf("no header for range: %d", msgTypeRange) } func (protocol *Protocol) GetMsgType(defaultMsgType uint32) (uint32, error) { diff --git a/device/awg/magic_header.go b/device/awg/magic_header.go index f7ad45c..df8bdf1 100644 --- a/device/awg/magic_header.go +++ b/device/awg/magic_header.go @@ -35,6 +35,8 @@ func ParseMagicHeader(key, value string) (MagicHeader, error) { } return NewMagicHeader(uint32(magicHeader), uint32(magicHeader)) + } else if len(splitLimits[0]) == 0 || len(splitLimits[1]) == 0 { + return MagicHeader{}, fmt.Errorf("invalid value for key: %s; value: %s; expected format: min-max", key, value) } min, err := strconv.ParseUint(splitLimits[0], 10, 32) @@ -56,30 +58,30 @@ func ParseMagicHeader(key, value string) (MagicHeader, error) { } type MagicHeaders struct { - headers []MagicHeader - randomGenerator PRNG[uint32] + headerValues []MagicHeader + randomGenerator RandomNumberGenerator[uint32] } -func NewMagicHeaders(magicHeaders []MagicHeader) (MagicHeaders, error) { - if len(magicHeaders) != 4 { - return MagicHeaders{}, fmt.Errorf("all header types should be included: %v", magicHeaders) +func NewMagicHeaders(headerValues []MagicHeader) (MagicHeaders, error) { + if len(headerValues) != 4 { + return MagicHeaders{}, fmt.Errorf("all header types should be included: %v", headerValues) } - sortedMagicHeaders := slices.SortedFunc(slices.Values(magicHeaders), func(lhs MagicHeader, rhs MagicHeader) int { + sortedMagicHeaders := slices.SortedFunc(slices.Values(headerValues), func(lhs MagicHeader, rhs MagicHeader) int { return cmp.Compare(lhs.Min, rhs.Min) }) for i := range 3 { - if sortedMagicHeaders[i].Min > sortedMagicHeaders[i+1].Min { + if sortedMagicHeaders[i].Max >= sortedMagicHeaders[i+1].Min { return MagicHeaders{}, fmt.Errorf( "magic headers shouldn't overlap; %v > %v", - sortedMagicHeaders[i-1].Min, - sortedMagicHeaders[i].Min, + sortedMagicHeaders[i].Max, + sortedMagicHeaders[i+1].Min, ) } } - return MagicHeaders{headers: magicHeaders, randomGenerator: NewPRNG[uint32]()}, nil + return MagicHeaders{headerValues: headerValues, randomGenerator: NewPRNG[uint32]()}, nil } func (mh *MagicHeaders) Get(defaultMsgType uint32) (uint32, error) { @@ -87,5 +89,5 @@ func (mh *MagicHeaders) Get(defaultMsgType uint32) (uint32, error) { return 0, fmt.Errorf("invalid msg type: %d", defaultMsgType) } - return mh.randomGenerator.RandomSizeInRange(mh.headers[defaultMsgType-1].Min, mh.headers[defaultMsgType-1].Max), nil + return mh.randomGenerator.RandomSizeInRange(mh.headerValues[defaultMsgType-1].Min, mh.headerValues[defaultMsgType-1].Max), nil } diff --git a/device/awg/magic_header_test.go b/device/awg/magic_header_test.go new file mode 100644 index 0000000..d9e608c --- /dev/null +++ b/device/awg/magic_header_test.go @@ -0,0 +1,488 @@ +package awg + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestNewMagicHeaderSameValue(t *testing.T) { + tests := []struct { + name string + value uint32 + expected MagicHeader + }{ + { + name: "zero value", + value: 0, + expected: MagicHeader{Min: 0, Max: 0}, + }, + { + name: "small value", + value: 1, + expected: MagicHeader{Min: 1, Max: 1}, + }, + { + name: "large value", + value: 4294967295, // max uint32 + expected: MagicHeader{Min: 4294967295, Max: 4294967295}, + }, + { + name: "medium value", + value: 1000, + expected: MagicHeader{Min: 1000, Max: 1000}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + result := NewMagicHeaderSameValue(tt.value) + require.Equal(t, tt.expected, result) + }) + } +} + +func TestNewMagicHeader(t *testing.T) { + tests := []struct { + name string + min uint32 + max uint32 + expected MagicHeader + errorMsg string + }{ + { + name: "valid range", + min: 1, + max: 10, + expected: MagicHeader{Min: 1, Max: 10}, + }, + { + name: "equal values", + min: 5, + max: 5, + expected: MagicHeader{Min: 5, Max: 5}, + }, + { + name: "zero range", + min: 0, + max: 0, + expected: MagicHeader{Min: 0, Max: 0}, + }, + { + name: "max uint32 range", + min: 4294967294, + max: 4294967295, + expected: MagicHeader{Min: 4294967294, Max: 4294967295}, + }, + { + name: "min greater than max", + min: 10, + max: 5, + expected: MagicHeader{}, + errorMsg: "min (10) cannot be greater than max (5)", + }, + { + name: "large min greater than max", + min: 4294967295, + max: 1, + expected: MagicHeader{}, + errorMsg: "min (4294967295) cannot be greater than max (1)", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + result, err := NewMagicHeader(tt.min, tt.max) + + if tt.errorMsg != "" { + require.Error(t, err) + require.Contains(t, err.Error(), tt.errorMsg) + require.Equal(t, MagicHeader{}, result) + } else { + require.NoError(t, err) + require.Equal(t, tt.expected, result) + } + }) + } +} + +func TestParseMagicHeader(t *testing.T) { + tests := []struct { + name string + key string + value string + expected MagicHeader + errorMsg string + }{ + { + name: "single value", + key: "header1", + value: "100", + expected: MagicHeader{Min: 100, Max: 100}, + }, + { + name: "valid range", + key: "header2", + value: "10-20", + expected: MagicHeader{Min: 10, Max: 20}, + }, + { + name: "zero single value", + key: "header3", + value: "0", + expected: MagicHeader{Min: 0, Max: 0}, + }, + { + name: "zero range", + key: "header4", + value: "0-0", + expected: MagicHeader{Min: 0, Max: 0}, + }, + { + name: "max uint32 single", + key: "header5", + value: "4294967295", + expected: MagicHeader{Min: 4294967295, Max: 4294967295}, + }, + { + name: "max uint32 range", + key: "header6", + value: "4294967294-4294967295", + expected: MagicHeader{Min: 4294967294, Max: 4294967295}, + }, + { + name: "invalid single value - not number", + key: "header7", + value: "abc", + expected: MagicHeader{}, + errorMsg: "parse key: header7; value: abc;", + }, + { + name: "invalid single value - negative", + key: "header8", + value: "-5", + expected: MagicHeader{}, + errorMsg: "invalid value for key: header8; value: -5;", + }, + { + name: "invalid single value - too large", + key: "header9", + value: "4294967296", + expected: MagicHeader{}, + errorMsg: "parse key: header9; value: 4294967296;", + }, + { + name: "invalid range - min not number", + key: "header10", + value: "abc-10", + expected: MagicHeader{}, + errorMsg: "parse min key: header10; value: abc;", + }, + { + name: "invalid range - max not number", + key: "header11", + value: "10-abc", + expected: MagicHeader{}, + errorMsg: "parse max key: header11; value: abc;", + }, + { + name: "invalid range - min greater than max", + key: "header12", + value: "20-10", + expected: MagicHeader{}, + errorMsg: "new magicHeader key: header12; value: 20-10;", + }, + { + name: "invalid range - too many parts", + key: "header13", + value: "10-20-30", + expected: MagicHeader{}, + errorMsg: "parse key: header13; value: 10-20-30;", + }, + { + name: "empty value", + key: "header14", + value: "", + expected: MagicHeader{}, + errorMsg: "parse key: header14; value: ;", + }, + { + name: "hyphen only", + key: "header15", + value: "-", + expected: MagicHeader{}, + errorMsg: "invalid value for key: header15; value: -;", + }, + { + name: "empty min", + key: "header16", + value: "-10", + expected: MagicHeader{}, + errorMsg: "invalid value for key: header16; value: -10;", + }, + { + name: "empty max", + key: "header17", + value: "10-", + expected: MagicHeader{}, + errorMsg: "invalid value for key: header17; value: 10-;", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + result, err := ParseMagicHeader(tt.key, tt.value) + + if tt.errorMsg != "" { + require.Error(t, err) + require.Contains(t, err.Error(), tt.errorMsg) + require.Equal(t, MagicHeader{}, result) + } else { + require.NoError(t, err) + require.Equal(t, tt.expected, result) + } + }) + } +} + +func TestNewMagicHeaders(t *testing.T) { + tests := []struct { + name string + magicHeaders []MagicHeader + errorMsg string + }{ + { + name: "valid non-overlapping headers", + magicHeaders: []MagicHeader{ + {Min: 1, Max: 10}, + {Min: 11, Max: 20}, + {Min: 21, Max: 30}, + {Min: 31, Max: 40}, + }, + }, + { + name: "valid adjacent headers", + magicHeaders: []MagicHeader{ + {Min: 1, Max: 1}, + {Min: 2, Max: 2}, + {Min: 3, Max: 3}, + {Min: 4, Max: 4}, + }, + }, + { + name: "valid zero-based headers", + magicHeaders: []MagicHeader{ + {Min: 0, Max: 0}, + {Min: 1, Max: 1}, + {Min: 2, Max: 2}, + {Min: 3, Max: 3}, + }, + }, + { + name: "valid large value headers", + magicHeaders: []MagicHeader{ + {Min: 4294967290, Max: 4294967291}, + {Min: 4294967292, Max: 4294967293}, + {Min: 4294967294, Max: 4294967294}, + {Min: 4294967295, Max: 4294967295}, + }, + }, + { + name: "too few headers", + magicHeaders: []MagicHeader{ + {Min: 1, Max: 10}, + {Min: 11, Max: 20}, + {Min: 21, Max: 30}, + }, + errorMsg: "all header types should be included:", + }, + { + name: "too many headers", + magicHeaders: []MagicHeader{ + {Min: 1, Max: 10}, + {Min: 11, Max: 20}, + {Min: 21, Max: 30}, + {Min: 31, Max: 40}, + {Min: 41, Max: 50}, + }, + errorMsg: "all header types should be included:", + }, + { + name: "empty headers", + magicHeaders: []MagicHeader{}, + errorMsg: "all header types should be included:", + }, + { + name: "overlapping headers", + magicHeaders: []MagicHeader{ + {Min: 1, Max: 15}, + {Min: 10, Max: 20}, + {Min: 25, Max: 30}, + {Min: 35, Max: 40}, + }, + errorMsg: "magic headers shouldn't overlap;", + }, + { + name: "overlapping headers at limit-first", + magicHeaders: []MagicHeader{ + {Min: 1, Max: 10}, + {Min: 10, Max: 20}, + {Min: 25, Max: 30}, + {Min: 35, Max: 40}, + }, + errorMsg: "magic headers shouldn't overlap;", + }, + { + name: "overlapping headers at limit-second", + magicHeaders: []MagicHeader{ + {Min: 1, Max: 10}, + {Min: 15, Max: 25}, + {Min: 25, Max: 30}, + {Min: 35, Max: 40}, + }, + errorMsg: "magic headers shouldn't overlap;", + }, + { + name: "overlapping headers at limit-third", + magicHeaders: []MagicHeader{ + {Min: 1, Max: 10}, + {Min: 15, Max: 25}, + {Min: 30, Max: 35}, + {Min: 35, Max: 40}, + }, + errorMsg: "magic headers shouldn't overlap;", + }, + { + name: "identical ranges", + magicHeaders: []MagicHeader{ + {Min: 10, Max: 20}, + {Min: 10, Max: 20}, + {Min: 25, Max: 30}, + {Min: 35, Max: 40}, + }, + errorMsg: "magic headers shouldn't overlap;", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + result, err := NewMagicHeaders(tt.magicHeaders) + + if tt.errorMsg != "" { + require.Error(t, err) + require.Contains(t, err.Error(), tt.errorMsg) + require.Equal(t, MagicHeaders{}, result) + } else { + require.NoError(t, err) + require.Equal(t, tt.magicHeaders, result.headerValues) + require.NotNil(t, result.randomGenerator) + } + }) + } +} + +// Mock PRNG for testing +type mockPRNG struct { + returnValue uint32 +} + +func (m *mockPRNG) RandomSizeInRange(min, max uint32) uint32 { + return m.returnValue +} + +func (m *mockPRNG) Get() uint64 { + return 0 +} +func (m *mockPRNG) ReadSize(size int) []byte { + return make([]byte, size) +} + +func TestMagicHeaders_Get(t *testing.T) { + // Create test headers + headers := []MagicHeader{ + {Min: 1, Max: 10}, + {Min: 11, Max: 20}, + {Min: 21, Max: 30}, + {Min: 31, Max: 40}, + } + + tests := []struct { + name string + defaultMsgType uint32 + mockValue uint32 + expectedValue uint32 + errorMsg string + }{ + { + name: "valid type 1", + defaultMsgType: 1, + mockValue: 5, + expectedValue: 5, + }, + { + name: "valid type 2", + defaultMsgType: 2, + mockValue: 15, + expectedValue: 15, + }, + { + name: "valid type 3", + defaultMsgType: 3, + mockValue: 25, + expectedValue: 25, + }, + { + name: "valid type 4", + defaultMsgType: 4, + mockValue: 35, + expectedValue: 35, + }, + { + name: "invalid type 0", + defaultMsgType: 0, + mockValue: 0, + expectedValue: 0, + errorMsg: "invalid msg type: 0", + }, + { + name: "invalid type 5", + defaultMsgType: 5, + mockValue: 0, + expectedValue: 0, + errorMsg: "invalid msg type: 5", + }, + { + name: "invalid type max uint32", + defaultMsgType: 4294967295, + mockValue: 0, + expectedValue: 0, + errorMsg: "invalid msg type: 4294967295", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + // Create a new instance with mock PRNG for each test + testMagicHeaders := MagicHeaders{ + headerValues: headers, + randomGenerator: &mockPRNG{returnValue: tt.mockValue}, + } + + result, err := testMagicHeaders.Get(tt.defaultMsgType) + + if tt.errorMsg != "" { + require.Error(t, err) + require.Contains(t, err.Error(), tt.errorMsg) + require.Equal(t, uint32(0), result) + } else { + require.NoError(t, err) + require.Equal(t, tt.expectedValue, result) + } + }) + } +} diff --git a/device/awg/util.go b/device/awg/prng.go similarity index 84% rename from device/awg/util.go rename to device/awg/prng.go index 164e2b2..e7661d7 100644 --- a/device/awg/util.go +++ b/device/awg/prng.go @@ -7,6 +7,12 @@ import ( "golang.org/x/exp/constraints" ) +type RandomNumberGenerator[T constraints.Integer] interface { + RandomSizeInRange(min, max T) T + Get() uint64 + ReadSize(size int) []byte +} + type PRNG[T constraints.Integer] struct { cha8Rand *v2.ChaCha8 } diff --git a/device/device.go b/device/device.go index 1f1c035..1a2b726 100644 --- a/device/device.go +++ b/device/device.go @@ -883,7 +883,7 @@ func (device *Device) Logic(size int, packet *[]byte, bufsArrs *[MaxMessageSize] func (device *Device) getMsgType(packet *[]byte, junkSize int) (uint32, error) { msgTypeRange := binary.LittleEndian.Uint32((*packet)[junkSize : junkSize+4]) - msgType, err := device.awg.GetLimitMin(msgTypeRange) + msgType, err := device.awg.GetMagicHeaderMinFor(msgTypeRange) if err != nil { return 0, fmt.Errorf("get limit min: %w", err)