mirror of
https://github.com/amnezia-vpn/amneziawg-go.git
synced 2025-08-02 09:52:49 +02:00
feat: restructure random value generation
This commit is contained in:
parent
c5312e2740
commit
cc5cfcdb25
3 changed files with 79 additions and 78 deletions
|
@ -2,59 +2,42 @@ package awg
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
crand "crypto/rand"
|
||||
"fmt"
|
||||
v2 "math/rand/v2"
|
||||
)
|
||||
|
||||
type junkCreator struct {
|
||||
aSecCfg aSecCfgType
|
||||
cha8Rand *v2.ChaCha8
|
||||
randomGenerator PRNG[int]
|
||||
}
|
||||
|
||||
// TODO: refactor param to only pass the junk related params
|
||||
func NewJunkCreator(aSecCfg aSecCfgType) (junkCreator, error) {
|
||||
buf := make([]byte, 32)
|
||||
_, err := crand.Read(buf)
|
||||
if err != nil {
|
||||
return junkCreator{}, err
|
||||
}
|
||||
return junkCreator{aSecCfg: aSecCfg, cha8Rand: v2.NewChaCha8([32]byte(buf))}, nil
|
||||
func NewJunkCreator(aSecCfg aSecCfgType) junkCreator {
|
||||
return junkCreator{aSecCfg: aSecCfg, randomGenerator: NewPRNG[int]()}
|
||||
}
|
||||
|
||||
// Should be called with aSecMux RLocked
|
||||
func (jc *junkCreator) CreateJunkPackets(junks *[][]byte) error {
|
||||
func (jc *junkCreator) CreateJunkPackets(junks *[][]byte) {
|
||||
if jc.aSecCfg.JunkPacketCount == 0 {
|
||||
return nil
|
||||
return
|
||||
}
|
||||
|
||||
for range jc.aSecCfg.JunkPacketCount {
|
||||
packetSize := jc.randomPacketSize()
|
||||
junk, err := jc.randomJunkWithSize(packetSize)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create junk packet: %v", err)
|
||||
}
|
||||
junk := jc.randomJunkWithSize(packetSize)
|
||||
*junks = append(*junks, junk)
|
||||
}
|
||||
return nil
|
||||
return
|
||||
}
|
||||
|
||||
// Should be called with aSecMux RLocked
|
||||
func (jc *junkCreator) randomPacketSize() int {
|
||||
return int(
|
||||
jc.cha8Rand.Uint64()%uint64(
|
||||
jc.aSecCfg.JunkPacketMaxSize-jc.aSecCfg.JunkPacketMinSize,
|
||||
),
|
||||
) + jc.aSecCfg.JunkPacketMinSize
|
||||
return jc.randomGenerator.RandomSizeInRange(jc.aSecCfg.JunkPacketMinSize, jc.aSecCfg.JunkPacketMaxSize)
|
||||
}
|
||||
|
||||
// Should be called with aSecMux RLocked
|
||||
func (jc *junkCreator) AppendJunk(writer *bytes.Buffer, size int) error {
|
||||
headerJunk, err := jc.randomJunkWithSize(size)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create header junk: %v", err)
|
||||
}
|
||||
_, err = writer.Write(headerJunk)
|
||||
headerJunk := jc.randomJunkWithSize(size)
|
||||
_, err := writer.Write(headerJunk)
|
||||
if err != nil {
|
||||
return fmt.Errorf("write header junk: %v", err)
|
||||
}
|
||||
|
@ -62,9 +45,6 @@ func (jc *junkCreator) AppendJunk(writer *bytes.Buffer, size int) error {
|
|||
}
|
||||
|
||||
// Should be called with aSecMux RLocked
|
||||
func (jc *junkCreator) randomJunkWithSize(size int) ([]byte, error) {
|
||||
// TODO: use a memory pool to allocate
|
||||
junk := make([]byte, size)
|
||||
_, err := jc.cha8Rand.Read(junk)
|
||||
return junk, err
|
||||
func (jc *junkCreator) randomJunkWithSize(size int) []byte {
|
||||
return jc.randomGenerator.ReadSize(size)
|
||||
}
|
||||
|
|
|
@ -6,43 +6,29 @@ import (
|
|||
"testing"
|
||||
)
|
||||
|
||||
func setUpJunkCreator(t *testing.T) (junkCreator, error) {
|
||||
jc, err := NewJunkCreator(aSecCfgType{
|
||||
func setUpJunkCreator() junkCreator {
|
||||
jc := NewJunkCreator(aSecCfgType{
|
||||
IsSet: true,
|
||||
JunkPacketCount: 5,
|
||||
JunkPacketMinSize: 500,
|
||||
JunkPacketMaxSize: 1000,
|
||||
InitHeaderJunkSize: 30,
|
||||
ResponseHeaderJunkSize: 40,
|
||||
InitPacketMagicHeader: 123456,
|
||||
ResponsePacketMagicHeader: 67543,
|
||||
UnderloadPacketMagicHeader: 32345,
|
||||
TransportPacketMagicHeader: 123123,
|
||||
// TODO
|
||||
// InitPacketMagicHeader: 123456,
|
||||
// ResponsePacketMagicHeader: 67543,
|
||||
// UnderloadPacketMagicHeader: 32345,
|
||||
// TransportPacketMagicHeader: 123123,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("failed to create junk creator %v", err)
|
||||
return junkCreator{}, err
|
||||
}
|
||||
|
||||
return jc, nil
|
||||
return jc
|
||||
}
|
||||
|
||||
func Test_junkCreator_createJunkPackets(t *testing.T) {
|
||||
jc, err := setUpJunkCreator(t)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
jc := setUpJunkCreator()
|
||||
t.Run("valid", func(t *testing.T) {
|
||||
got := make([][]byte, 0, jc.aSecCfg.JunkPacketCount)
|
||||
err := jc.CreateJunkPackets(&got)
|
||||
if err != nil {
|
||||
t.Errorf(
|
||||
"junkCreator.createJunkPackets() = %v; failed",
|
||||
err,
|
||||
)
|
||||
return
|
||||
}
|
||||
jc.CreateJunkPackets(&got)
|
||||
seen := make(map[string]bool)
|
||||
for _, junk := range got {
|
||||
key := string(junk)
|
||||
|
@ -61,25 +47,19 @@ func Test_junkCreator_createJunkPackets(t *testing.T) {
|
|||
|
||||
func Test_junkCreator_randomJunkWithSize(t *testing.T) {
|
||||
t.Run("valid", func(t *testing.T) {
|
||||
jc, err := setUpJunkCreator(t)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
r1, _ := jc.randomJunkWithSize(10)
|
||||
r2, _ := jc.randomJunkWithSize(10)
|
||||
jc := setUpJunkCreator()
|
||||
r1 := jc.randomJunkWithSize(10)
|
||||
r2 := jc.randomJunkWithSize(10)
|
||||
fmt.Printf("%v\n%v\n", r1, r2)
|
||||
if bytes.Equal(r1, r2) {
|
||||
t.Errorf("same junks %v", err)
|
||||
t.Errorf("same junks")
|
||||
return
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func Test_junkCreator_randomPacketSize(t *testing.T) {
|
||||
jc, err := setUpJunkCreator(t)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
jc := setUpJunkCreator()
|
||||
for range [30]struct{}{} {
|
||||
t.Run("valid", func(t *testing.T) {
|
||||
if got := jc.randomPacketSize(); jc.aSecCfg.JunkPacketMinSize > got ||
|
||||
|
@ -96,10 +76,7 @@ func Test_junkCreator_randomPacketSize(t *testing.T) {
|
|||
}
|
||||
|
||||
func Test_junkCreator_appendJunk(t *testing.T) {
|
||||
jc, err := setUpJunkCreator(t)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
jc := setUpJunkCreator()
|
||||
t.Run("valid", func(t *testing.T) {
|
||||
s := "apple"
|
||||
buffer := bytes.NewBuffer([]byte(s))
|
||||
|
|
44
device/awg/util.go
Normal file
44
device/awg/util.go
Normal file
|
@ -0,0 +1,44 @@
|
|||
package awg
|
||||
|
||||
import (
|
||||
crand "crypto/rand"
|
||||
v2 "math/rand/v2"
|
||||
|
||||
"golang.org/x/exp/constraints"
|
||||
)
|
||||
|
||||
type PRNG[T constraints.Integer] struct {
|
||||
cha8Rand *v2.ChaCha8
|
||||
}
|
||||
|
||||
func NewPRNG[T constraints.Integer]() PRNG[T] {
|
||||
buf := make([]byte, 32)
|
||||
_, _ = crand.Read(buf)
|
||||
|
||||
return PRNG[T]{
|
||||
cha8Rand: v2.NewChaCha8([32]byte(buf)),
|
||||
}
|
||||
}
|
||||
|
||||
func (p PRNG[T]) RandomSizeInRange(min, max T) T {
|
||||
if min >= max {
|
||||
panic("min must be less than max")
|
||||
}
|
||||
|
||||
if min == max {
|
||||
return min
|
||||
}
|
||||
|
||||
return T(p.Get()%uint64(max-min)) + min
|
||||
}
|
||||
|
||||
func (p PRNG[T]) Get() uint64 {
|
||||
return p.cha8Rand.Uint64()
|
||||
}
|
||||
|
||||
func (p PRNG[T]) ReadSize(size int) []byte {
|
||||
// TODO: use a memory pool to allocate
|
||||
buf := make([]byte, size)
|
||||
_, _ = p.cha8Rand.Read(buf)
|
||||
return buf
|
||||
}
|
Loading…
Add table
Reference in a new issue