mirror of
https://github.com/amnezia-vpn/amneziawg-go.git
synced 2025-07-31 17:12:49 +02:00
feat: some generators & parser improvements
This commit is contained in:
parent
d96900ba17
commit
431b7b1a37
4 changed files with 253 additions and 38 deletions
|
@ -1,14 +1,106 @@
|
|||
package junktag
|
||||
|
||||
import (
|
||||
crand "crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
v2 "math/rand/v2"
|
||||
)
|
||||
|
||||
type Generator interface {
|
||||
Generate() []byte
|
||||
Generate() ([]byte, error)
|
||||
}
|
||||
|
||||
type newGenerator func(string) (Generator, error)
|
||||
|
||||
type BytesGenerator struct {
|
||||
value []byte
|
||||
}
|
||||
|
||||
func (b *BytesGenerator) Generate() []byte {
|
||||
return nil
|
||||
func (bg *BytesGenerator) Generate() ([]byte, error) {
|
||||
return bg.value, nil
|
||||
}
|
||||
|
||||
func newBytesGenerator(param string) (Generator, error) {
|
||||
isNotHex := !strings.HasPrefix(param, "0x") ||
|
||||
!strings.HasPrefix(param, "0x") && !isHexString(param)
|
||||
if isNotHex {
|
||||
return nil, fmt.Errorf("not correct hex: %s", param)
|
||||
}
|
||||
|
||||
hex, err := hexToBytes(param)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("hexToBytes: %w", err)
|
||||
}
|
||||
|
||||
return &BytesGenerator{value: hex}, nil
|
||||
}
|
||||
|
||||
func isHexString(s string) bool {
|
||||
for _, char := range s {
|
||||
if !((char >= '0' && char <= '9') ||
|
||||
(char >= 'a' && char <= 'f') ||
|
||||
(char >= 'A' && char <= 'F')) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return len(s) > 0
|
||||
}
|
||||
|
||||
func hexToBytes(hexStr string) ([]byte, error) {
|
||||
hexStr = strings.TrimPrefix(hexStr, "0x")
|
||||
hexStr = strings.TrimPrefix(hexStr, "0X")
|
||||
|
||||
// Ensure even length (pad with leading zero if needed)
|
||||
if len(hexStr)%2 != 0 {
|
||||
hexStr = "0" + hexStr
|
||||
}
|
||||
|
||||
return hex.DecodeString(hexStr)
|
||||
}
|
||||
|
||||
type RandomPacketGenerator struct {
|
||||
cha8Rand *v2.ChaCha8
|
||||
size int
|
||||
}
|
||||
|
||||
func (rpg *RandomPacketGenerator) Generate() ([]byte, error) {
|
||||
junk := make([]byte, rpg.size)
|
||||
_, err := rpg.cha8Rand.Read(junk)
|
||||
return junk, err
|
||||
}
|
||||
|
||||
func newRandomPacketGenerator(param string) (Generator, error) {
|
||||
size, err := strconv.Atoi(param)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("randome packet parse int: %w", err)
|
||||
}
|
||||
// TODO: add size check
|
||||
|
||||
buf := make([]byte, 32)
|
||||
_, err = crand.Read(buf)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("randome packet crand read: %w", err)
|
||||
}
|
||||
|
||||
return &RandomPacketGenerator{cha8Rand: v2.NewChaCha8([32]byte(buf)), size: size}, nil
|
||||
}
|
||||
|
||||
type TimestampGenerator struct {
|
||||
}
|
||||
|
||||
func (tg *TimestampGenerator) Generate() ([]byte, error) {
|
||||
return time.Now().MarshalBinary()
|
||||
}
|
||||
|
||||
func newTimestampGenerator(param string) (Generator, error) {
|
||||
if len(param) != 0 {
|
||||
return nil, fmt.Errorf("timestamp param needs to be empty: %s", param)
|
||||
}
|
||||
|
||||
return &TimestampGenerator{}, nil
|
||||
}
|
||||
|
|
124
device/internal/junk-tag/generator_test.go
Normal file
124
device/internal/junk-tag/generator_test.go
Normal file
|
@ -0,0 +1,124 @@
|
|||
package junktag
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_newBytesGenerator(t *testing.T) {
|
||||
type args struct {
|
||||
param string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want []byte
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
args: args{
|
||||
param: "",
|
||||
},
|
||||
wantErr: fmt.Errorf("not correct hex"),
|
||||
},
|
||||
{
|
||||
name: "wrong start",
|
||||
args: args{
|
||||
param: "123456",
|
||||
},
|
||||
wantErr: fmt.Errorf("not correct hex"),
|
||||
},
|
||||
{
|
||||
name: "not only hex value",
|
||||
args: args{
|
||||
param: "0x12345q",
|
||||
},
|
||||
wantErr: fmt.Errorf("not correct hex"),
|
||||
},
|
||||
{
|
||||
name: "valid hex",
|
||||
args: args{
|
||||
param: "0xf6ab3267fa",
|
||||
},
|
||||
want: []byte{0xf6, 0xab, 0x32, 0x67, 0xfa},
|
||||
},
|
||||
{
|
||||
name: "valid hex with odd length",
|
||||
args: args{
|
||||
param: "0xfab3267fa",
|
||||
},
|
||||
want: []byte{0xf, 0xab, 0x32, 0x67, 0xfa},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := newBytesGenerator(tt.args.param)
|
||||
|
||||
if tt.wantErr != nil {
|
||||
require.ErrorAs(t, err, &tt.wantErr)
|
||||
require.Nil(t, got)
|
||||
return
|
||||
}
|
||||
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, got)
|
||||
|
||||
gotValues, _ := got.Generate()
|
||||
require.Equal(t, tt.want, gotValues)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_newRandomPacketGenerator(t *testing.T) {
|
||||
type args struct {
|
||||
param string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
args: args{
|
||||
param: "",
|
||||
},
|
||||
wantErr: fmt.Errorf("parse int"),
|
||||
},
|
||||
{
|
||||
name: "not an int",
|
||||
args: args{
|
||||
param: "x",
|
||||
},
|
||||
wantErr: fmt.Errorf("parse int"),
|
||||
},
|
||||
{
|
||||
name: "valid",
|
||||
args: args{
|
||||
param: "12",
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := newRandomPacketGenerator(tt.args.param)
|
||||
if tt.wantErr != nil {
|
||||
require.ErrorAs(t, err, &tt.wantErr)
|
||||
require.Nil(t, got)
|
||||
return
|
||||
}
|
||||
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, got)
|
||||
first, err := got.Generate()
|
||||
require.Nil(t, err)
|
||||
|
||||
second, err := got.Generate()
|
||||
require.Nil(t, err)
|
||||
require.NotEqual(t, first, second)
|
||||
})
|
||||
}
|
||||
}
|
|
@ -18,10 +18,10 @@ const (
|
|||
)
|
||||
|
||||
var validEnum = map[Enum]newGenerator{
|
||||
EnumBytes: func(s string) (Generator, error) { return &BytesGenerator{}, nil },
|
||||
EnumBytes: newBytesGenerator,
|
||||
EnumCounter: func(s string) (Generator, error) { return &BytesGenerator{}, nil },
|
||||
EnumTimestamp: func(s string) (Generator, error) { return &BytesGenerator{}, nil },
|
||||
EnumRandomBytes: func(s string) (Generator, error) { return &BytesGenerator{}, nil },
|
||||
EnumTimestamp: newTimestampGenerator,
|
||||
EnumRandomBytes: newRandomPacketGenerator,
|
||||
EnumWaitTimeout: func(s string) (Generator, error) { return &BytesGenerator{}, nil },
|
||||
EnumWaitResponse: func(s string) (Generator, error) { return &BytesGenerator{}, nil },
|
||||
}
|
||||
|
@ -35,24 +35,19 @@ type Tag struct {
|
|||
Param string
|
||||
}
|
||||
|
||||
func parseTags(input string) ([]Tag, error) {
|
||||
func parseTag(input string) (Tag, error) {
|
||||
// Regular expression to match <tagname optional_param>
|
||||
re := regexp.MustCompile(`([a-zA-Z]+)(?:\s+([^>]+))?>`)
|
||||
|
||||
matches := re.FindAllStringSubmatch(input, -1)
|
||||
tags := make([]Tag, 0, len(matches))
|
||||
|
||||
for _, match := range matches {
|
||||
tag := Tag{
|
||||
Name: Enum(match[1]),
|
||||
}
|
||||
if len(match) > 2 && match[2] != "" {
|
||||
tag.Param = strings.TrimSpace(match[2])
|
||||
}
|
||||
tags = append(tags, tag)
|
||||
match := re.FindStringSubmatch(input)
|
||||
tag := Tag{
|
||||
Name: Enum(match[1]),
|
||||
}
|
||||
if len(match) > 2 && match[2] != "" {
|
||||
tag.Param = strings.TrimSpace(match[2])
|
||||
}
|
||||
|
||||
return tags, nil
|
||||
return tag, nil
|
||||
}
|
||||
|
||||
func Parse(input string) (Foo, error) {
|
||||
|
@ -62,26 +57,29 @@ func Parse(input string) (Foo, error) {
|
|||
return Foo{}, fmt.Errorf("empty input: %s", input)
|
||||
}
|
||||
|
||||
for _, inputParam := range inputSlice[1:] {
|
||||
if len(inputParam) == 1 {
|
||||
// skip byproduct of split
|
||||
inputSlice = inputSlice[1:]
|
||||
rv := Foo{x: make([]Generator, 0, len(inputSlice))}
|
||||
|
||||
for _, inputParam := range inputSlice {
|
||||
if len(inputParam) <= 1 {
|
||||
return Foo{}, fmt.Errorf("empty tag in input: %s", inputSlice)
|
||||
} else if strings.Count(inputParam, ">") != 1 {
|
||||
return Foo{}, fmt.Errorf("ill formated input: %s", input)
|
||||
}
|
||||
|
||||
tags, _ := parseTags(inputParam)
|
||||
for _, tag := range tags {
|
||||
fmt.Printf("Tag: %s, Param: %s\n", tag.Name, tag.Param)
|
||||
gen, ok := validEnum[tag.Name]
|
||||
if !ok {
|
||||
return Foo{}, fmt.Errorf("invalid tag")
|
||||
}
|
||||
_, err := gen(tag.Param)
|
||||
if err != nil {
|
||||
return Foo{}, fmt.Errorf("")
|
||||
}
|
||||
tag, _ := parseTag(inputParam)
|
||||
fmt.Printf("Tag: %s, Param: %s\n", tag.Name, tag.Param)
|
||||
gen, ok := validEnum[tag.Name]
|
||||
if !ok {
|
||||
return Foo{}, fmt.Errorf("invalid tag: %s", tag.Name)
|
||||
}
|
||||
generator, err := gen(tag.Param)
|
||||
if err != nil {
|
||||
return Foo{}, fmt.Errorf("gen: %w", err)
|
||||
}
|
||||
rv.x = append(rv.x, generator)
|
||||
}
|
||||
|
||||
return Foo{}, nil
|
||||
return rv, nil
|
||||
}
|
||||
|
|
|
@ -29,7 +29,7 @@ func TestParse(t *testing.T) {
|
|||
{
|
||||
name: "extra <",
|
||||
args: args{input: "<<b 0xf6ab3267fa><c>"},
|
||||
wantErr: fmt.Errorf("ill formated input"),
|
||||
wantErr: fmt.Errorf("empty tag in input"),
|
||||
},
|
||||
{
|
||||
name: "empty <>",
|
||||
|
@ -51,10 +51,11 @@ func TestParse(t *testing.T) {
|
|||
_, err := Parse(tt.args.input)
|
||||
|
||||
// TODO: ErrorAs doesn't work as you think
|
||||
// if tt.wantErr != nil {
|
||||
require.ErrorAs(t, err, &tt.wantErr)
|
||||
// return
|
||||
// }
|
||||
if tt.wantErr != nil {
|
||||
require.ErrorAs(t, err, &tt.wantErr)
|
||||
return
|
||||
}
|
||||
require.Nil(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue