mirror of
https://github.com/amnezia-vpn/amneziawg-go.git
synced 2025-08-02 17:52:50 +02:00
feat: restructure random value generation
This commit is contained in:
parent
c5312e2740
commit
cc5cfcdb25
3 changed files with 79 additions and 78 deletions
|
@ -2,59 +2,42 @@ package awg
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
crand "crypto/rand"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
v2 "math/rand/v2"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type junkCreator struct {
|
type junkCreator struct {
|
||||||
aSecCfg aSecCfgType
|
aSecCfg aSecCfgType
|
||||||
cha8Rand *v2.ChaCha8
|
randomGenerator PRNG[int]
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: refactor param to only pass the junk related params
|
// TODO: refactor param to only pass the junk related params
|
||||||
func NewJunkCreator(aSecCfg aSecCfgType) (junkCreator, error) {
|
func NewJunkCreator(aSecCfg aSecCfgType) junkCreator {
|
||||||
buf := make([]byte, 32)
|
return junkCreator{aSecCfg: aSecCfg, randomGenerator: NewPRNG[int]()}
|
||||||
_, err := crand.Read(buf)
|
|
||||||
if err != nil {
|
|
||||||
return junkCreator{}, err
|
|
||||||
}
|
|
||||||
return junkCreator{aSecCfg: aSecCfg, cha8Rand: v2.NewChaCha8([32]byte(buf))}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Should be called with aSecMux RLocked
|
// Should be called with aSecMux RLocked
|
||||||
func (jc *junkCreator) CreateJunkPackets(junks *[][]byte) error {
|
func (jc *junkCreator) CreateJunkPackets(junks *[][]byte) {
|
||||||
if jc.aSecCfg.JunkPacketCount == 0 {
|
if jc.aSecCfg.JunkPacketCount == 0 {
|
||||||
return nil
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
for range jc.aSecCfg.JunkPacketCount {
|
for range jc.aSecCfg.JunkPacketCount {
|
||||||
packetSize := jc.randomPacketSize()
|
packetSize := jc.randomPacketSize()
|
||||||
junk, err := jc.randomJunkWithSize(packetSize)
|
junk := jc.randomJunkWithSize(packetSize)
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("create junk packet: %v", err)
|
|
||||||
}
|
|
||||||
*junks = append(*junks, junk)
|
*junks = append(*junks, junk)
|
||||||
}
|
}
|
||||||
return nil
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Should be called with aSecMux RLocked
|
// Should be called with aSecMux RLocked
|
||||||
func (jc *junkCreator) randomPacketSize() int {
|
func (jc *junkCreator) randomPacketSize() int {
|
||||||
return int(
|
return jc.randomGenerator.RandomSizeInRange(jc.aSecCfg.JunkPacketMinSize, jc.aSecCfg.JunkPacketMaxSize)
|
||||||
jc.cha8Rand.Uint64()%uint64(
|
|
||||||
jc.aSecCfg.JunkPacketMaxSize-jc.aSecCfg.JunkPacketMinSize,
|
|
||||||
),
|
|
||||||
) + jc.aSecCfg.JunkPacketMinSize
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Should be called with aSecMux RLocked
|
// Should be called with aSecMux RLocked
|
||||||
func (jc *junkCreator) AppendJunk(writer *bytes.Buffer, size int) error {
|
func (jc *junkCreator) AppendJunk(writer *bytes.Buffer, size int) error {
|
||||||
headerJunk, err := jc.randomJunkWithSize(size)
|
headerJunk := jc.randomJunkWithSize(size)
|
||||||
if err != nil {
|
_, err := writer.Write(headerJunk)
|
||||||
return fmt.Errorf("create header junk: %v", err)
|
|
||||||
}
|
|
||||||
_, err = writer.Write(headerJunk)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("write header junk: %v", err)
|
return fmt.Errorf("write header junk: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -62,9 +45,6 @@ func (jc *junkCreator) AppendJunk(writer *bytes.Buffer, size int) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Should be called with aSecMux RLocked
|
// Should be called with aSecMux RLocked
|
||||||
func (jc *junkCreator) randomJunkWithSize(size int) ([]byte, error) {
|
func (jc *junkCreator) randomJunkWithSize(size int) []byte {
|
||||||
// TODO: use a memory pool to allocate
|
return jc.randomGenerator.ReadSize(size)
|
||||||
junk := make([]byte, size)
|
|
||||||
_, err := jc.cha8Rand.Read(junk)
|
|
||||||
return junk, err
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,43 +6,29 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func setUpJunkCreator(t *testing.T) (junkCreator, error) {
|
func setUpJunkCreator() junkCreator {
|
||||||
jc, err := NewJunkCreator(aSecCfgType{
|
jc := NewJunkCreator(aSecCfgType{
|
||||||
IsSet: true,
|
IsSet: true,
|
||||||
JunkPacketCount: 5,
|
JunkPacketCount: 5,
|
||||||
JunkPacketMinSize: 500,
|
JunkPacketMinSize: 500,
|
||||||
JunkPacketMaxSize: 1000,
|
JunkPacketMaxSize: 1000,
|
||||||
InitHeaderJunkSize: 30,
|
InitHeaderJunkSize: 30,
|
||||||
ResponseHeaderJunkSize: 40,
|
ResponseHeaderJunkSize: 40,
|
||||||
InitPacketMagicHeader: 123456,
|
// TODO
|
||||||
ResponsePacketMagicHeader: 67543,
|
// InitPacketMagicHeader: 123456,
|
||||||
UnderloadPacketMagicHeader: 32345,
|
// ResponsePacketMagicHeader: 67543,
|
||||||
TransportPacketMagicHeader: 123123,
|
// UnderloadPacketMagicHeader: 32345,
|
||||||
|
// TransportPacketMagicHeader: 123123,
|
||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
return jc
|
||||||
t.Errorf("failed to create junk creator %v", err)
|
|
||||||
return junkCreator{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return jc, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_junkCreator_createJunkPackets(t *testing.T) {
|
func Test_junkCreator_createJunkPackets(t *testing.T) {
|
||||||
jc, err := setUpJunkCreator(t)
|
jc := setUpJunkCreator()
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
t.Run("valid", func(t *testing.T) {
|
t.Run("valid", func(t *testing.T) {
|
||||||
got := make([][]byte, 0, jc.aSecCfg.JunkPacketCount)
|
got := make([][]byte, 0, jc.aSecCfg.JunkPacketCount)
|
||||||
err := jc.CreateJunkPackets(&got)
|
jc.CreateJunkPackets(&got)
|
||||||
if err != nil {
|
|
||||||
t.Errorf(
|
|
||||||
"junkCreator.createJunkPackets() = %v; failed",
|
|
||||||
err,
|
|
||||||
)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
seen := make(map[string]bool)
|
seen := make(map[string]bool)
|
||||||
for _, junk := range got {
|
for _, junk := range got {
|
||||||
key := string(junk)
|
key := string(junk)
|
||||||
|
@ -61,25 +47,19 @@ func Test_junkCreator_createJunkPackets(t *testing.T) {
|
||||||
|
|
||||||
func Test_junkCreator_randomJunkWithSize(t *testing.T) {
|
func Test_junkCreator_randomJunkWithSize(t *testing.T) {
|
||||||
t.Run("valid", func(t *testing.T) {
|
t.Run("valid", func(t *testing.T) {
|
||||||
jc, err := setUpJunkCreator(t)
|
jc := setUpJunkCreator()
|
||||||
if err != nil {
|
r1 := jc.randomJunkWithSize(10)
|
||||||
return
|
r2 := jc.randomJunkWithSize(10)
|
||||||
}
|
|
||||||
r1, _ := jc.randomJunkWithSize(10)
|
|
||||||
r2, _ := jc.randomJunkWithSize(10)
|
|
||||||
fmt.Printf("%v\n%v\n", r1, r2)
|
fmt.Printf("%v\n%v\n", r1, r2)
|
||||||
if bytes.Equal(r1, r2) {
|
if bytes.Equal(r1, r2) {
|
||||||
t.Errorf("same junks %v", err)
|
t.Errorf("same junks")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_junkCreator_randomPacketSize(t *testing.T) {
|
func Test_junkCreator_randomPacketSize(t *testing.T) {
|
||||||
jc, err := setUpJunkCreator(t)
|
jc := setUpJunkCreator()
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
for range [30]struct{}{} {
|
for range [30]struct{}{} {
|
||||||
t.Run("valid", func(t *testing.T) {
|
t.Run("valid", func(t *testing.T) {
|
||||||
if got := jc.randomPacketSize(); jc.aSecCfg.JunkPacketMinSize > got ||
|
if got := jc.randomPacketSize(); jc.aSecCfg.JunkPacketMinSize > got ||
|
||||||
|
@ -96,10 +76,7 @@ func Test_junkCreator_randomPacketSize(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_junkCreator_appendJunk(t *testing.T) {
|
func Test_junkCreator_appendJunk(t *testing.T) {
|
||||||
jc, err := setUpJunkCreator(t)
|
jc := setUpJunkCreator()
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
t.Run("valid", func(t *testing.T) {
|
t.Run("valid", func(t *testing.T) {
|
||||||
s := "apple"
|
s := "apple"
|
||||||
buffer := bytes.NewBuffer([]byte(s))
|
buffer := bytes.NewBuffer([]byte(s))
|
||||||
|
|
44
device/awg/util.go
Normal file
44
device/awg/util.go
Normal file
|
@ -0,0 +1,44 @@
|
||||||
|
package awg
|
||||||
|
|
||||||
|
import (
|
||||||
|
crand "crypto/rand"
|
||||||
|
v2 "math/rand/v2"
|
||||||
|
|
||||||
|
"golang.org/x/exp/constraints"
|
||||||
|
)
|
||||||
|
|
||||||
|
type PRNG[T constraints.Integer] struct {
|
||||||
|
cha8Rand *v2.ChaCha8
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPRNG[T constraints.Integer]() PRNG[T] {
|
||||||
|
buf := make([]byte, 32)
|
||||||
|
_, _ = crand.Read(buf)
|
||||||
|
|
||||||
|
return PRNG[T]{
|
||||||
|
cha8Rand: v2.NewChaCha8([32]byte(buf)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p PRNG[T]) RandomSizeInRange(min, max T) T {
|
||||||
|
if min >= max {
|
||||||
|
panic("min must be less than max")
|
||||||
|
}
|
||||||
|
|
||||||
|
if min == max {
|
||||||
|
return min
|
||||||
|
}
|
||||||
|
|
||||||
|
return T(p.Get()%uint64(max-min)) + min
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p PRNG[T]) Get() uint64 {
|
||||||
|
return p.cha8Rand.Uint64()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p PRNG[T]) ReadSize(size int) []byte {
|
||||||
|
// TODO: use a memory pool to allocate
|
||||||
|
buf := make([]byte, size)
|
||||||
|
_, _ = p.cha8Rand.Read(buf)
|
||||||
|
return buf
|
||||||
|
}
|
Loading…
Add table
Reference in a new issue