From 70292c0ae3c601941f881cf91291b865f54cbefb Mon Sep 17 00:00:00 2001
From: Mark Puha
Date: Mon, 4 Aug 2025 18:18:19 +0200
Subject: [PATCH] feat: add different random generators
---
device/awg/tag_generator.go | 99 ++++++++++++++++++----
device/awg/tag_generator_test.go | 140 ++++++++++++++++++++++++++++++-
device/awg/tag_parser.go | 6 +-
device/device_test.go | 14 ++--
4 files changed, 232 insertions(+), 27 deletions(-)
diff --git a/device/awg/tag_generator.go b/device/awg/tag_generator.go
index 65d8004..3924ccb 100644
--- a/device/awg/tag_generator.go
+++ b/device/awg/tag_generator.go
@@ -59,43 +59,110 @@ func hexToBytes(hexStr string) ([]byte, error) {
return hex.DecodeString(hexStr)
}
-type RandomPacketGenerator struct {
+type randomGeneratorBase struct {
cha8Rand *v2.ChaCha8
size int
}
-func (rpg *RandomPacketGenerator) Generate() []byte {
- junk := make([]byte, rpg.size)
- rpg.cha8Rand.Read(junk)
- return junk
-}
-
-func (rpg *RandomPacketGenerator) Size() int {
- return rpg.size
-}
-
-func newRandomPacketGenerator(param string) (Generator, error) {
+func newRandomGeneratorBase(param string) (*randomGeneratorBase, error) {
size, err := strconv.Atoi(param)
if err != nil {
- return nil, fmt.Errorf("random packet parse int: %w", err)
+ return nil, fmt.Errorf("parse int: %w", err)
}
if size > 1000 {
- return nil, fmt.Errorf("random packet size must be less than 1000")
+ return nil, fmt.Errorf("size must be less than 1000")
}
buf := make([]byte, 32)
_, err = crand.Read(buf)
if err != nil {
- return nil, fmt.Errorf("random packet crand read: %w", err)
+ return nil, fmt.Errorf("crand read: %w", err)
}
- return &RandomPacketGenerator{
+ return &randomGeneratorBase{
cha8Rand: v2.NewChaCha8([32]byte(buf)),
size: size,
}, nil
}
+func (rpg *randomGeneratorBase) generate() []byte {
+ junk := make([]byte, rpg.size)
+ rpg.cha8Rand.Read(junk)
+ return junk
+}
+
+func (rpg *randomGeneratorBase) Size() int {
+ return rpg.size
+}
+
+type RandomBytesGenerator struct {
+ *randomGeneratorBase
+}
+
+func newRandomBytesGenerator(param string) (Generator, error) {
+ rpgBase, err := newRandomGeneratorBase(param)
+ if err != nil {
+ return nil, fmt.Errorf("new random bytes generator: %w", err)
+ }
+
+ return &RandomBytesGenerator{randomGeneratorBase: rpgBase}, nil
+}
+
+func (rpg *RandomBytesGenerator) Generate() []byte {
+ return rpg.generate()
+}
+
+const alphanumericChars = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
+
+type RandomASCIIGenerator struct {
+ *randomGeneratorBase
+}
+
+func newRandomASCIIGenerator(param string) (Generator, error) {
+ rpgBase, err := newRandomGeneratorBase(param)
+ if err != nil {
+ return nil, fmt.Errorf("new random ascii generator: %w", err)
+ }
+
+ return &RandomASCIIGenerator{randomGeneratorBase: rpgBase}, nil
+}
+
+func (rpg *RandomASCIIGenerator) Generate() []byte {
+ junk := rpg.generate()
+
+ result := make([]byte, rpg.size)
+ for i, b := range junk {
+ result[i] = alphanumericChars[b%byte(len(alphanumericChars))]
+ }
+
+ return result
+}
+
+type RandomDigitGenerator struct {
+ *randomGeneratorBase
+}
+
+func newRandomDigitGenerator(param string) (Generator, error) {
+ rpgBase, err := newRandomGeneratorBase(param)
+ if err != nil {
+ return nil, fmt.Errorf("new random digit generator: %w", err)
+ }
+
+ return &RandomDigitGenerator{randomGeneratorBase: rpgBase}, nil
+}
+
+func (rpg *RandomDigitGenerator) Generate() []byte {
+ junk := rpg.generate()
+
+ result := make([]byte, rpg.size)
+ for i, b := range junk {
+ result[i] = '0' + (b % 10) // Convert to digit character
+ }
+
+ return result
+}
+
type TimestampGenerator struct {
}
diff --git a/device/awg/tag_generator_test.go b/device/awg/tag_generator_test.go
index 4950b33..43efa67 100644
--- a/device/awg/tag_generator_test.go
+++ b/device/awg/tag_generator_test.go
@@ -8,7 +8,9 @@ import (
"github.com/stretchr/testify/require"
)
-func Test_newBytesGenerator(t *testing.T) {
+func TestNewBytesGenerator(t *testing.T) {
+ t.Parallel()
+
type args struct {
param string
}
@@ -63,6 +65,8 @@ func Test_newBytesGenerator(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
got, err := newBytesGenerator(tt.args.param)
if tt.wantErr != nil {
@@ -80,7 +84,9 @@ func Test_newBytesGenerator(t *testing.T) {
}
}
-func Test_newRandomPacketGenerator(t *testing.T) {
+func TestNewRandomBytesGenerator(t *testing.T) {
+ t.Parallel()
+
type args struct {
param string
}
@@ -117,9 +123,134 @@ func Test_newRandomPacketGenerator(t *testing.T) {
},
},
}
+
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- got, err := newRandomPacketGenerator(tt.args.param)
+ t.Parallel()
+
+ got, err := newRandomBytesGenerator(tt.args.param)
+ if tt.wantErr != nil {
+ require.ErrorAs(t, err, &tt.wantErr)
+ require.Nil(t, got)
+ return
+ }
+
+ require.Nil(t, err)
+ require.NotNil(t, got)
+ first := got.Generate()
+
+ second := got.Generate()
+ require.NotEqual(t, first, second)
+ })
+ }
+}
+
+func TestNewRandomASCIIGenerator(t *testing.T) {
+ t.Parallel()
+
+ type args struct {
+ param string
+ }
+ tests := []struct {
+ name string
+ args args
+ wantErr error
+ }{
+ {
+ name: "empty",
+ args: args{
+ param: "",
+ },
+ wantErr: fmt.Errorf("parse int"),
+ },
+ {
+ name: "not an int",
+ args: args{
+ param: "x",
+ },
+ wantErr: fmt.Errorf("parse int"),
+ },
+ {
+ name: "too large",
+ args: args{
+ param: "1001",
+ },
+ wantErr: fmt.Errorf("random packet size must be less than 1000"),
+ },
+ {
+ name: "valid",
+ args: args{
+ param: "12",
+ },
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ got, err := newRandomASCIIGenerator(tt.args.param)
+ if tt.wantErr != nil {
+ require.ErrorAs(t, err, &tt.wantErr)
+ require.Nil(t, got)
+ return
+ }
+
+ require.Nil(t, err)
+ require.NotNil(t, got)
+ first := got.Generate()
+
+ second := got.Generate()
+ require.NotEqual(t, first, second)
+ })
+ }
+}
+
+func TestNewRandomDigitGenerator(t *testing.T) {
+ t.Parallel()
+
+ type args struct {
+ param string
+ }
+ tests := []struct {
+ name string
+ args args
+ wantErr error
+ }{
+ {
+ name: "empty",
+ args: args{
+ param: "",
+ },
+ wantErr: fmt.Errorf("parse int"),
+ },
+ {
+ name: "not an int",
+ args: args{
+ param: "x",
+ },
+ wantErr: fmt.Errorf("parse int"),
+ },
+ {
+ name: "too large",
+ args: args{
+ param: "1001",
+ },
+ wantErr: fmt.Errorf("random packet size must be less than 1000"),
+ },
+ {
+ name: "valid",
+ args: args{
+ param: "12",
+ },
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ got, err := newRandomDigitGenerator(tt.args.param)
if tt.wantErr != nil {
require.ErrorAs(t, err, &tt.wantErr)
require.Nil(t, got)
@@ -137,6 +268,8 @@ func Test_newRandomPacketGenerator(t *testing.T) {
}
func TestPacketCounterGenerator(t *testing.T) {
+ t.Parallel()
+
tests := []struct {
name string
param string
@@ -155,7 +288,6 @@ func TestPacketCounterGenerator(t *testing.T) {
}
for _, tc := range tests {
- tc := tc // capture range variable
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
diff --git a/device/awg/tag_parser.go b/device/awg/tag_parser.go
index c2e1090..d36823f 100644
--- a/device/awg/tag_parser.go
+++ b/device/awg/tag_parser.go
@@ -16,6 +16,8 @@ const (
CounterEnumTag EnumTag = "c"
TimestampEnumTag EnumTag = "t"
RandomBytesEnumTag EnumTag = "r"
+ RandomASCIIEnumTag EnumTag = "rc"
+ RandomDigitEnumTag EnumTag = "rd"
WaitTimeoutEnumTag EnumTag = "wt"
WaitResponseEnumTag EnumTag = "wr"
)
@@ -24,7 +26,9 @@ var generatorCreator = map[EnumTag]newGenerator{
BytesEnumTag: newBytesGenerator,
CounterEnumTag: newPacketCounterGenerator,
TimestampEnumTag: newTimestampGenerator,
- RandomBytesEnumTag: newRandomPacketGenerator,
+ RandomBytesEnumTag: newRandomBytesGenerator,
+ RandomASCIIEnumTag: newRandomASCIIGenerator,
+ RandomDigitEnumTag: newRandomDigitGenerator,
WaitTimeoutEnumTag: newWaitTimeoutGenerator,
// WaitResponseEnumTag: newWaitResponseGenerator,
}
diff --git a/device/device_test.go b/device/device_test.go
index 5820128..8d7421e 100644
--- a/device/device_test.go
+++ b/device/device_test.go
@@ -264,12 +264,14 @@ func TestAWGHandshakeDevicePing(t *testing.T) {
goroutineLeakCheck(t)
pair := genTestPair(t, true,
- "i1", "",
- "i2", "",
- "j1", "",
- "j2", "",
- "j3", "",
- "itime", "60",
+ "i1", "",
+ "i2", "",
+ "i3", "",
+ "i4", "",
+ // "j1", "",
+ // "j2", "",
+ // "j3", "",
+ // "itime", "60",
// "jc", "1",
// "jmin", "500",
// "jmax", "1000",