mirror of
https://github.com/amnezia-vpn/amneziawg-go.git
synced 2025-09-03 00:23:00 +02:00
feat: add different random generators
This commit is contained in:
parent
6639458d0a
commit
70292c0ae3
4 changed files with 232 additions and 27 deletions
|
@ -59,43 +59,110 @@ func hexToBytes(hexStr string) ([]byte, error) {
|
|||
return hex.DecodeString(hexStr)
|
||||
}
|
||||
|
||||
type RandomPacketGenerator struct {
|
||||
type randomGeneratorBase struct {
|
||||
cha8Rand *v2.ChaCha8
|
||||
size int
|
||||
}
|
||||
|
||||
func (rpg *RandomPacketGenerator) Generate() []byte {
|
||||
junk := make([]byte, rpg.size)
|
||||
rpg.cha8Rand.Read(junk)
|
||||
return junk
|
||||
}
|
||||
|
||||
func (rpg *RandomPacketGenerator) Size() int {
|
||||
return rpg.size
|
||||
}
|
||||
|
||||
func newRandomPacketGenerator(param string) (Generator, error) {
|
||||
func newRandomGeneratorBase(param string) (*randomGeneratorBase, error) {
|
||||
size, err := strconv.Atoi(param)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("random packet parse int: %w", err)
|
||||
return nil, fmt.Errorf("parse int: %w", err)
|
||||
}
|
||||
|
||||
if size > 1000 {
|
||||
return nil, fmt.Errorf("random packet size must be less than 1000")
|
||||
return nil, fmt.Errorf("size must be less than 1000")
|
||||
}
|
||||
|
||||
buf := make([]byte, 32)
|
||||
_, err = crand.Read(buf)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("random packet crand read: %w", err)
|
||||
return nil, fmt.Errorf("crand read: %w", err)
|
||||
}
|
||||
|
||||
return &RandomPacketGenerator{
|
||||
return &randomGeneratorBase{
|
||||
cha8Rand: v2.NewChaCha8([32]byte(buf)),
|
||||
size: size,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (rpg *randomGeneratorBase) generate() []byte {
|
||||
junk := make([]byte, rpg.size)
|
||||
rpg.cha8Rand.Read(junk)
|
||||
return junk
|
||||
}
|
||||
|
||||
func (rpg *randomGeneratorBase) Size() int {
|
||||
return rpg.size
|
||||
}
|
||||
|
||||
type RandomBytesGenerator struct {
|
||||
*randomGeneratorBase
|
||||
}
|
||||
|
||||
func newRandomBytesGenerator(param string) (Generator, error) {
|
||||
rpgBase, err := newRandomGeneratorBase(param)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("new random bytes generator: %w", err)
|
||||
}
|
||||
|
||||
return &RandomBytesGenerator{randomGeneratorBase: rpgBase}, nil
|
||||
}
|
||||
|
||||
func (rpg *RandomBytesGenerator) Generate() []byte {
|
||||
return rpg.generate()
|
||||
}
|
||||
|
||||
const alphanumericChars = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
|
||||
|
||||
type RandomASCIIGenerator struct {
|
||||
*randomGeneratorBase
|
||||
}
|
||||
|
||||
func newRandomASCIIGenerator(param string) (Generator, error) {
|
||||
rpgBase, err := newRandomGeneratorBase(param)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("new random ascii generator: %w", err)
|
||||
}
|
||||
|
||||
return &RandomASCIIGenerator{randomGeneratorBase: rpgBase}, nil
|
||||
}
|
||||
|
||||
func (rpg *RandomASCIIGenerator) Generate() []byte {
|
||||
junk := rpg.generate()
|
||||
|
||||
result := make([]byte, rpg.size)
|
||||
for i, b := range junk {
|
||||
result[i] = alphanumericChars[b%byte(len(alphanumericChars))]
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
type RandomDigitGenerator struct {
|
||||
*randomGeneratorBase
|
||||
}
|
||||
|
||||
func newRandomDigitGenerator(param string) (Generator, error) {
|
||||
rpgBase, err := newRandomGeneratorBase(param)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("new random digit generator: %w", err)
|
||||
}
|
||||
|
||||
return &RandomDigitGenerator{randomGeneratorBase: rpgBase}, nil
|
||||
}
|
||||
|
||||
func (rpg *RandomDigitGenerator) Generate() []byte {
|
||||
junk := rpg.generate()
|
||||
|
||||
result := make([]byte, rpg.size)
|
||||
for i, b := range junk {
|
||||
result[i] = '0' + (b % 10) // Convert to digit character
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
type TimestampGenerator struct {
|
||||
}
|
||||
|
||||
|
|
|
@ -8,7 +8,9 @@ import (
|
|||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_newBytesGenerator(t *testing.T) {
|
||||
func TestNewBytesGenerator(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
type args struct {
|
||||
param string
|
||||
}
|
||||
|
@ -63,6 +65,8 @@ func Test_newBytesGenerator(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got, err := newBytesGenerator(tt.args.param)
|
||||
|
||||
if tt.wantErr != nil {
|
||||
|
@ -80,7 +84,9 @@ func Test_newBytesGenerator(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func Test_newRandomPacketGenerator(t *testing.T) {
|
||||
func TestNewRandomBytesGenerator(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
type args struct {
|
||||
param string
|
||||
}
|
||||
|
@ -117,9 +123,134 @@ func Test_newRandomPacketGenerator(t *testing.T) {
|
|||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := newRandomPacketGenerator(tt.args.param)
|
||||
t.Parallel()
|
||||
|
||||
got, err := newRandomBytesGenerator(tt.args.param)
|
||||
if tt.wantErr != nil {
|
||||
require.ErrorAs(t, err, &tt.wantErr)
|
||||
require.Nil(t, got)
|
||||
return
|
||||
}
|
||||
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, got)
|
||||
first := got.Generate()
|
||||
|
||||
second := got.Generate()
|
||||
require.NotEqual(t, first, second)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewRandomASCIIGenerator(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
type args struct {
|
||||
param string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
args: args{
|
||||
param: "",
|
||||
},
|
||||
wantErr: fmt.Errorf("parse int"),
|
||||
},
|
||||
{
|
||||
name: "not an int",
|
||||
args: args{
|
||||
param: "x",
|
||||
},
|
||||
wantErr: fmt.Errorf("parse int"),
|
||||
},
|
||||
{
|
||||
name: "too large",
|
||||
args: args{
|
||||
param: "1001",
|
||||
},
|
||||
wantErr: fmt.Errorf("random packet size must be less than 1000"),
|
||||
},
|
||||
{
|
||||
name: "valid",
|
||||
args: args{
|
||||
param: "12",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got, err := newRandomASCIIGenerator(tt.args.param)
|
||||
if tt.wantErr != nil {
|
||||
require.ErrorAs(t, err, &tt.wantErr)
|
||||
require.Nil(t, got)
|
||||
return
|
||||
}
|
||||
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, got)
|
||||
first := got.Generate()
|
||||
|
||||
second := got.Generate()
|
||||
require.NotEqual(t, first, second)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewRandomDigitGenerator(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
type args struct {
|
||||
param string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
args: args{
|
||||
param: "",
|
||||
},
|
||||
wantErr: fmt.Errorf("parse int"),
|
||||
},
|
||||
{
|
||||
name: "not an int",
|
||||
args: args{
|
||||
param: "x",
|
||||
},
|
||||
wantErr: fmt.Errorf("parse int"),
|
||||
},
|
||||
{
|
||||
name: "too large",
|
||||
args: args{
|
||||
param: "1001",
|
||||
},
|
||||
wantErr: fmt.Errorf("random packet size must be less than 1000"),
|
||||
},
|
||||
{
|
||||
name: "valid",
|
||||
args: args{
|
||||
param: "12",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got, err := newRandomDigitGenerator(tt.args.param)
|
||||
if tt.wantErr != nil {
|
||||
require.ErrorAs(t, err, &tt.wantErr)
|
||||
require.Nil(t, got)
|
||||
|
@ -137,6 +268,8 @@ func Test_newRandomPacketGenerator(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestPacketCounterGenerator(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
param string
|
||||
|
@ -155,7 +288,6 @@ func TestPacketCounterGenerator(t *testing.T) {
|
|||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc // capture range variable
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
|
|
@ -16,6 +16,8 @@ const (
|
|||
CounterEnumTag EnumTag = "c"
|
||||
TimestampEnumTag EnumTag = "t"
|
||||
RandomBytesEnumTag EnumTag = "r"
|
||||
RandomASCIIEnumTag EnumTag = "rc"
|
||||
RandomDigitEnumTag EnumTag = "rd"
|
||||
WaitTimeoutEnumTag EnumTag = "wt"
|
||||
WaitResponseEnumTag EnumTag = "wr"
|
||||
)
|
||||
|
@ -24,7 +26,9 @@ var generatorCreator = map[EnumTag]newGenerator{
|
|||
BytesEnumTag: newBytesGenerator,
|
||||
CounterEnumTag: newPacketCounterGenerator,
|
||||
TimestampEnumTag: newTimestampGenerator,
|
||||
RandomBytesEnumTag: newRandomPacketGenerator,
|
||||
RandomBytesEnumTag: newRandomBytesGenerator,
|
||||
RandomASCIIEnumTag: newRandomASCIIGenerator,
|
||||
RandomDigitEnumTag: newRandomDigitGenerator,
|
||||
WaitTimeoutEnumTag: newWaitTimeoutGenerator,
|
||||
// WaitResponseEnumTag: newWaitResponseGenerator,
|
||||
}
|
||||
|
|
|
@ -264,12 +264,14 @@ func TestAWGHandshakeDevicePing(t *testing.T) {
|
|||
|
||||
goroutineLeakCheck(t)
|
||||
pair := genTestPair(t, true,
|
||||
"i1", "<b 0xf6ab3267fa><c><b 0xf6ab><t><r 10><wt 10>",
|
||||
"i2", "<b 0xf6ab3267fa><r 100>",
|
||||
"j1", "<b 0xffffffff><c><b 0xf6ab><t><r 10>",
|
||||
"j2", "<c><b 0xf6ab><t><wt 1000>",
|
||||
"j3", "<t><b 0xf6ab><c><r 10>",
|
||||
"itime", "60",
|
||||
"i1", "<b 0xf6ab3267fa><c><b 0xf6ab><t><r 10>",
|
||||
"i2", "<b 0xf6ab3267fa><c><b 0xf6ab><t><rc 10>",
|
||||
"i3", "<b 0xf6ab3267fa><c><b 0xf6ab><t><rd 10>",
|
||||
"i4", "<b 0xf6ab3267fa><r 100>",
|
||||
// "j1", "<b 0xffffffff><c><b 0xf6ab><t><r 10>",
|
||||
// "j2", "<c><b 0xf6ab><t><wt 1000>",
|
||||
// "j3", "<t><b 0xf6ab><c><r 10>",
|
||||
// "itime", "60",
|
||||
// "jc", "1",
|
||||
// "jmin", "500",
|
||||
// "jmax", "1000",
|
||||
|
|
Loading…
Add table
Reference in a new issue