diff --git a/device/awg/junk_creator.go b/device/awg/junk_creator.go index 91fd253..b41441a 100644 --- a/device/awg/junk_creator.go +++ b/device/awg/junk_creator.go @@ -2,59 +2,42 @@ package awg import ( "bytes" - crand "crypto/rand" "fmt" - v2 "math/rand/v2" ) type junkCreator struct { - aSecCfg aSecCfgType - cha8Rand *v2.ChaCha8 + aSecCfg aSecCfgType + randomGenerator PRNG[int] } // TODO: refactor param to only pass the junk related params -func NewJunkCreator(aSecCfg aSecCfgType) (junkCreator, error) { - buf := make([]byte, 32) - _, err := crand.Read(buf) - if err != nil { - return junkCreator{}, err - } - return junkCreator{aSecCfg: aSecCfg, cha8Rand: v2.NewChaCha8([32]byte(buf))}, nil +func NewJunkCreator(aSecCfg aSecCfgType) junkCreator { + return junkCreator{aSecCfg: aSecCfg, randomGenerator: NewPRNG[int]()} } // Should be called with aSecMux RLocked -func (jc *junkCreator) CreateJunkPackets(junks *[][]byte) error { +func (jc *junkCreator) CreateJunkPackets(junks *[][]byte) { if jc.aSecCfg.JunkPacketCount == 0 { - return nil + return } for range jc.aSecCfg.JunkPacketCount { packetSize := jc.randomPacketSize() - junk, err := jc.randomJunkWithSize(packetSize) - if err != nil { - return fmt.Errorf("create junk packet: %v", err) - } + junk := jc.randomJunkWithSize(packetSize) *junks = append(*junks, junk) } - return nil + return } // Should be called with aSecMux RLocked func (jc *junkCreator) randomPacketSize() int { - return int( - jc.cha8Rand.Uint64()%uint64( - jc.aSecCfg.JunkPacketMaxSize-jc.aSecCfg.JunkPacketMinSize, - ), - ) + jc.aSecCfg.JunkPacketMinSize + return jc.randomGenerator.RandomSizeInRange(jc.aSecCfg.JunkPacketMinSize, jc.aSecCfg.JunkPacketMaxSize) } // Should be called with aSecMux RLocked func (jc *junkCreator) AppendJunk(writer *bytes.Buffer, size int) error { - headerJunk, err := jc.randomJunkWithSize(size) - if err != nil { - return fmt.Errorf("create header junk: %v", err) - } - _, err = writer.Write(headerJunk) + headerJunk := jc.randomJunkWithSize(size) + _, err := writer.Write(headerJunk) if err != nil { return fmt.Errorf("write header junk: %v", err) } @@ -62,9 +45,6 @@ func (jc *junkCreator) AppendJunk(writer *bytes.Buffer, size int) error { } // Should be called with aSecMux RLocked -func (jc *junkCreator) randomJunkWithSize(size int) ([]byte, error) { - // TODO: use a memory pool to allocate - junk := make([]byte, size) - _, err := jc.cha8Rand.Read(junk) - return junk, err +func (jc *junkCreator) randomJunkWithSize(size int) []byte { + return jc.randomGenerator.ReadSize(size) } diff --git a/device/awg/junk_creator_test.go b/device/awg/junk_creator_test.go index 424f104..3553a93 100644 --- a/device/awg/junk_creator_test.go +++ b/device/awg/junk_creator_test.go @@ -6,43 +6,29 @@ import ( "testing" ) -func setUpJunkCreator(t *testing.T) (junkCreator, error) { - jc, err := NewJunkCreator(aSecCfgType{ - IsSet: true, - JunkPacketCount: 5, - JunkPacketMinSize: 500, - JunkPacketMaxSize: 1000, - InitHeaderJunkSize: 30, - ResponseHeaderJunkSize: 40, - InitPacketMagicHeader: 123456, - ResponsePacketMagicHeader: 67543, - UnderloadPacketMagicHeader: 32345, - TransportPacketMagicHeader: 123123, +func setUpJunkCreator() junkCreator { + jc := NewJunkCreator(aSecCfgType{ + IsSet: true, + JunkPacketCount: 5, + JunkPacketMinSize: 500, + JunkPacketMaxSize: 1000, + InitHeaderJunkSize: 30, + ResponseHeaderJunkSize: 40, + // TODO + // InitPacketMagicHeader: 123456, + // ResponsePacketMagicHeader: 67543, + // UnderloadPacketMagicHeader: 32345, + // TransportPacketMagicHeader: 123123, }) - if err != nil { - t.Errorf("failed to create junk creator %v", err) - return junkCreator{}, err - } - - return jc, nil + return jc } func Test_junkCreator_createJunkPackets(t *testing.T) { - jc, err := setUpJunkCreator(t) - if err != nil { - return - } + jc := setUpJunkCreator() t.Run("valid", func(t *testing.T) { got := make([][]byte, 0, jc.aSecCfg.JunkPacketCount) - err := jc.CreateJunkPackets(&got) - if err != nil { - t.Errorf( - "junkCreator.createJunkPackets() = %v; failed", - err, - ) - return - } + jc.CreateJunkPackets(&got) seen := make(map[string]bool) for _, junk := range got { key := string(junk) @@ -61,25 +47,19 @@ func Test_junkCreator_createJunkPackets(t *testing.T) { func Test_junkCreator_randomJunkWithSize(t *testing.T) { t.Run("valid", func(t *testing.T) { - jc, err := setUpJunkCreator(t) - if err != nil { - return - } - r1, _ := jc.randomJunkWithSize(10) - r2, _ := jc.randomJunkWithSize(10) + jc := setUpJunkCreator() + r1 := jc.randomJunkWithSize(10) + r2 := jc.randomJunkWithSize(10) fmt.Printf("%v\n%v\n", r1, r2) if bytes.Equal(r1, r2) { - t.Errorf("same junks %v", err) + t.Errorf("same junks") return } }) } func Test_junkCreator_randomPacketSize(t *testing.T) { - jc, err := setUpJunkCreator(t) - if err != nil { - return - } + jc := setUpJunkCreator() for range [30]struct{}{} { t.Run("valid", func(t *testing.T) { if got := jc.randomPacketSize(); jc.aSecCfg.JunkPacketMinSize > got || @@ -96,10 +76,7 @@ func Test_junkCreator_randomPacketSize(t *testing.T) { } func Test_junkCreator_appendJunk(t *testing.T) { - jc, err := setUpJunkCreator(t) - if err != nil { - return - } + jc := setUpJunkCreator() t.Run("valid", func(t *testing.T) { s := "apple" buffer := bytes.NewBuffer([]byte(s)) diff --git a/device/awg/util.go b/device/awg/util.go new file mode 100644 index 0000000..5cb0caa --- /dev/null +++ b/device/awg/util.go @@ -0,0 +1,44 @@ +package awg + +import ( + crand "crypto/rand" + v2 "math/rand/v2" + + "golang.org/x/exp/constraints" +) + +type PRNG[T constraints.Integer] struct { + cha8Rand *v2.ChaCha8 +} + +func NewPRNG[T constraints.Integer]() PRNG[T] { + buf := make([]byte, 32) + _, _ = crand.Read(buf) + + return PRNG[T]{ + cha8Rand: v2.NewChaCha8([32]byte(buf)), + } +} + +func (p PRNG[T]) RandomSizeInRange(min, max T) T { + if min >= max { + panic("min must be less than max") + } + + if min == max { + return min + } + + return T(p.Get()%uint64(max-min)) + min +} + +func (p PRNG[T]) Get() uint64 { + return p.cha8Rand.Uint64() +} + +func (p PRNG[T]) ReadSize(size int) []byte { + // TODO: use a memory pool to allocate + buf := make([]byte, size) + _, _ = p.cha8Rand.Read(buf) + return buf +}