feat: restructure random value generation

This commit is contained in:
Mark Puha 2025-07-08 19:12:26 +02:00
parent c5312e2740
commit cc5cfcdb25
3 changed files with 79 additions and 78 deletions

View file

@ -2,59 +2,42 @@ package awg
import ( import (
"bytes" "bytes"
crand "crypto/rand"
"fmt" "fmt"
v2 "math/rand/v2"
) )
type junkCreator struct { type junkCreator struct {
aSecCfg aSecCfgType aSecCfg aSecCfgType
cha8Rand *v2.ChaCha8 randomGenerator PRNG[int]
} }
// TODO: refactor param to only pass the junk related params // TODO: refactor param to only pass the junk related params
func NewJunkCreator(aSecCfg aSecCfgType) (junkCreator, error) { func NewJunkCreator(aSecCfg aSecCfgType) junkCreator {
buf := make([]byte, 32) return junkCreator{aSecCfg: aSecCfg, randomGenerator: NewPRNG[int]()}
_, err := crand.Read(buf)
if err != nil {
return junkCreator{}, err
}
return junkCreator{aSecCfg: aSecCfg, cha8Rand: v2.NewChaCha8([32]byte(buf))}, nil
} }
// Should be called with aSecMux RLocked // Should be called with aSecMux RLocked
func (jc *junkCreator) CreateJunkPackets(junks *[][]byte) error { func (jc *junkCreator) CreateJunkPackets(junks *[][]byte) {
if jc.aSecCfg.JunkPacketCount == 0 { if jc.aSecCfg.JunkPacketCount == 0 {
return nil return
} }
for range jc.aSecCfg.JunkPacketCount { for range jc.aSecCfg.JunkPacketCount {
packetSize := jc.randomPacketSize() packetSize := jc.randomPacketSize()
junk, err := jc.randomJunkWithSize(packetSize) junk := jc.randomJunkWithSize(packetSize)
if err != nil {
return fmt.Errorf("create junk packet: %v", err)
}
*junks = append(*junks, junk) *junks = append(*junks, junk)
} }
return nil return
} }
// Should be called with aSecMux RLocked // Should be called with aSecMux RLocked
func (jc *junkCreator) randomPacketSize() int { func (jc *junkCreator) randomPacketSize() int {
return int( return jc.randomGenerator.RandomSizeInRange(jc.aSecCfg.JunkPacketMinSize, jc.aSecCfg.JunkPacketMaxSize)
jc.cha8Rand.Uint64()%uint64(
jc.aSecCfg.JunkPacketMaxSize-jc.aSecCfg.JunkPacketMinSize,
),
) + jc.aSecCfg.JunkPacketMinSize
} }
// Should be called with aSecMux RLocked // Should be called with aSecMux RLocked
func (jc *junkCreator) AppendJunk(writer *bytes.Buffer, size int) error { func (jc *junkCreator) AppendJunk(writer *bytes.Buffer, size int) error {
headerJunk, err := jc.randomJunkWithSize(size) headerJunk := jc.randomJunkWithSize(size)
if err != nil { _, err := writer.Write(headerJunk)
return fmt.Errorf("create header junk: %v", err)
}
_, err = writer.Write(headerJunk)
if err != nil { if err != nil {
return fmt.Errorf("write header junk: %v", err) return fmt.Errorf("write header junk: %v", err)
} }
@ -62,9 +45,6 @@ func (jc *junkCreator) AppendJunk(writer *bytes.Buffer, size int) error {
} }
// Should be called with aSecMux RLocked // Should be called with aSecMux RLocked
func (jc *junkCreator) randomJunkWithSize(size int) ([]byte, error) { func (jc *junkCreator) randomJunkWithSize(size int) []byte {
// TODO: use a memory pool to allocate return jc.randomGenerator.ReadSize(size)
junk := make([]byte, size)
_, err := jc.cha8Rand.Read(junk)
return junk, err
} }

View file

@ -6,43 +6,29 @@ import (
"testing" "testing"
) )
func setUpJunkCreator(t *testing.T) (junkCreator, error) { func setUpJunkCreator() junkCreator {
jc, err := NewJunkCreator(aSecCfgType{ jc := NewJunkCreator(aSecCfgType{
IsSet: true, IsSet: true,
JunkPacketCount: 5, JunkPacketCount: 5,
JunkPacketMinSize: 500, JunkPacketMinSize: 500,
JunkPacketMaxSize: 1000, JunkPacketMaxSize: 1000,
InitHeaderJunkSize: 30, InitHeaderJunkSize: 30,
ResponseHeaderJunkSize: 40, ResponseHeaderJunkSize: 40,
InitPacketMagicHeader: 123456, // TODO
ResponsePacketMagicHeader: 67543, // InitPacketMagicHeader: 123456,
UnderloadPacketMagicHeader: 32345, // ResponsePacketMagicHeader: 67543,
TransportPacketMagicHeader: 123123, // UnderloadPacketMagicHeader: 32345,
// TransportPacketMagicHeader: 123123,
}) })
if err != nil { return jc
t.Errorf("failed to create junk creator %v", err)
return junkCreator{}, err
}
return jc, nil
} }
func Test_junkCreator_createJunkPackets(t *testing.T) { func Test_junkCreator_createJunkPackets(t *testing.T) {
jc, err := setUpJunkCreator(t) jc := setUpJunkCreator()
if err != nil {
return
}
t.Run("valid", func(t *testing.T) { t.Run("valid", func(t *testing.T) {
got := make([][]byte, 0, jc.aSecCfg.JunkPacketCount) got := make([][]byte, 0, jc.aSecCfg.JunkPacketCount)
err := jc.CreateJunkPackets(&got) jc.CreateJunkPackets(&got)
if err != nil {
t.Errorf(
"junkCreator.createJunkPackets() = %v; failed",
err,
)
return
}
seen := make(map[string]bool) seen := make(map[string]bool)
for _, junk := range got { for _, junk := range got {
key := string(junk) key := string(junk)
@ -61,25 +47,19 @@ func Test_junkCreator_createJunkPackets(t *testing.T) {
func Test_junkCreator_randomJunkWithSize(t *testing.T) { func Test_junkCreator_randomJunkWithSize(t *testing.T) {
t.Run("valid", func(t *testing.T) { t.Run("valid", func(t *testing.T) {
jc, err := setUpJunkCreator(t) jc := setUpJunkCreator()
if err != nil { r1 := jc.randomJunkWithSize(10)
return r2 := jc.randomJunkWithSize(10)
}
r1, _ := jc.randomJunkWithSize(10)
r2, _ := jc.randomJunkWithSize(10)
fmt.Printf("%v\n%v\n", r1, r2) fmt.Printf("%v\n%v\n", r1, r2)
if bytes.Equal(r1, r2) { if bytes.Equal(r1, r2) {
t.Errorf("same junks %v", err) t.Errorf("same junks")
return return
} }
}) })
} }
func Test_junkCreator_randomPacketSize(t *testing.T) { func Test_junkCreator_randomPacketSize(t *testing.T) {
jc, err := setUpJunkCreator(t) jc := setUpJunkCreator()
if err != nil {
return
}
for range [30]struct{}{} { for range [30]struct{}{} {
t.Run("valid", func(t *testing.T) { t.Run("valid", func(t *testing.T) {
if got := jc.randomPacketSize(); jc.aSecCfg.JunkPacketMinSize > got || if got := jc.randomPacketSize(); jc.aSecCfg.JunkPacketMinSize > got ||
@ -96,10 +76,7 @@ func Test_junkCreator_randomPacketSize(t *testing.T) {
} }
func Test_junkCreator_appendJunk(t *testing.T) { func Test_junkCreator_appendJunk(t *testing.T) {
jc, err := setUpJunkCreator(t) jc := setUpJunkCreator()
if err != nil {
return
}
t.Run("valid", func(t *testing.T) { t.Run("valid", func(t *testing.T) {
s := "apple" s := "apple"
buffer := bytes.NewBuffer([]byte(s)) buffer := bytes.NewBuffer([]byte(s))

44
device/awg/util.go Normal file
View file

@ -0,0 +1,44 @@
package awg
import (
crand "crypto/rand"
v2 "math/rand/v2"
"golang.org/x/exp/constraints"
)
type PRNG[T constraints.Integer] struct {
cha8Rand *v2.ChaCha8
}
func NewPRNG[T constraints.Integer]() PRNG[T] {
buf := make([]byte, 32)
_, _ = crand.Read(buf)
return PRNG[T]{
cha8Rand: v2.NewChaCha8([32]byte(buf)),
}
}
func (p PRNG[T]) RandomSizeInRange(min, max T) T {
if min >= max {
panic("min must be less than max")
}
if min == max {
return min
}
return T(p.Get()%uint64(max-min)) + min
}
func (p PRNG[T]) Get() uint64 {
return p.cha8Rand.Uint64()
}
func (p PRNG[T]) ReadSize(size int) []byte {
// TODO: use a memory pool to allocate
buf := make([]byte, size)
_, _ = p.cha8Rand.Read(buf)
return buf
}