chore: project restructure

This commit is contained in:
Mark Puha 2025-06-09 16:41:54 +02:00
parent a1d8adca48
commit 65743536a2
19 changed files with 356 additions and 385 deletions

30
device/awg/awg.go Normal file
View file

@ -0,0 +1,30 @@
package awg
import (
"sync"
"github.com/tevino/abool"
)
type Protocol struct {
IsASecOn abool.AtomicBool
// TODO: revision the need of the mutex
ASecMux sync.RWMutex
ASecCfg aSecCfgType
JunkCreator junkCreator
HandshakeHandler SpecialHandshakeHandler
}
type aSecCfgType struct {
IsSet bool
JunkPacketCount int
JunkPacketMinSize int
JunkPacketMaxSize int
InitPacketJunkSize int
ResponsePacketJunkSize int
InitPacketMagicHeader uint32
ResponsePacketMagicHeader uint32
UnderloadPacketMagicHeader uint32
TransportPacketMagicHeader uint32
}

View file

@ -1,4 +1,4 @@
package device package awg
import ( import (
"bytes" "bytes"
@ -8,33 +8,32 @@ import (
) )
type junkCreator struct { type junkCreator struct {
device *Device aSecCfg aSecCfgType
cha8Rand *v2.ChaCha8 cha8Rand *v2.ChaCha8
} }
func NewJunkCreator(d *Device) (junkCreator, error) { func NewJunkCreator(aSecCfg aSecCfgType) (junkCreator, error) {
buf := make([]byte, 32) buf := make([]byte, 32)
_, err := crand.Read(buf) _, err := crand.Read(buf)
if err != nil { if err != nil {
return junkCreator{}, err return junkCreator{}, err
} }
return junkCreator{device: d, cha8Rand: v2.NewChaCha8([32]byte(buf))}, nil return junkCreator{aSecCfg: aSecCfg, cha8Rand: v2.NewChaCha8([32]byte(buf))}, nil
} }
// Should be called with aSecMux RLocked // Should be called with aSecMux RLocked
func (jc *junkCreator) createJunkPackets(junks *[][]byte) error { func (jc *junkCreator) CreateJunkPackets(junks [][]byte) error {
if jc.device.awg.aSecCfg.junkPacketCount == 0 { if jc.aSecCfg.JunkPacketCount == 0 {
return nil return nil
} }
*junks = make([][]byte, len(*junks)+jc.device.awg.aSecCfg.junkPacketCount) for i := range jc.aSecCfg.JunkPacketCount {
for i := range jc.device.awg.aSecCfg.junkPacketCount {
packetSize := jc.randomPacketSize() packetSize := jc.randomPacketSize()
junk, err := jc.randomJunkWithSize(packetSize) junk, err := jc.randomJunkWithSize(packetSize)
if err != nil { if err != nil {
return fmt.Errorf("create junk packet: %v", err) return fmt.Errorf("create junk packet: %v", err)
} }
(*junks)[i] = junk junks[i] = junk
} }
return nil return nil
} }
@ -43,13 +42,13 @@ func (jc *junkCreator) createJunkPackets(junks *[][]byte) error {
func (jc *junkCreator) randomPacketSize() int { func (jc *junkCreator) randomPacketSize() int {
return int( return int(
jc.cha8Rand.Uint64()%uint64( jc.cha8Rand.Uint64()%uint64(
jc.device.awg.aSecCfg.junkPacketMaxSize-jc.device.awg.aSecCfg.junkPacketMinSize, jc.aSecCfg.JunkPacketMaxSize-jc.aSecCfg.JunkPacketMinSize,
), ),
) + jc.device.awg.aSecCfg.junkPacketMinSize ) + jc.aSecCfg.JunkPacketMinSize
} }
// Should be called with aSecMux RLocked // Should be called with aSecMux RLocked
func (jc *junkCreator) appendJunk(writer *bytes.Buffer, size int) error { func (jc *junkCreator) AppendJunk(writer *bytes.Buffer, size int) error {
headerJunk, err := jc.randomJunkWithSize(size) headerJunk, err := jc.randomJunkWithSize(size)
if err != nil { if err != nil {
return fmt.Errorf("create header junk: %v", err) return fmt.Errorf("create header junk: %v", err)

View file

@ -1,36 +1,27 @@
package device package awg
import ( import (
"bytes" "bytes"
"fmt" "fmt"
"testing" "testing"
"github.com/amnezia-vpn/amneziawg-go/conn/bindtest"
"github.com/amnezia-vpn/amneziawg-go/tun/tuntest"
) )
func setUpJunkCreator(t *testing.T) (junkCreator, error) { func setUpJunkCreator(t *testing.T) (junkCreator, error) {
cfg, _ := genASecurityConfigs(t) jc, err := NewJunkCreator(aSecCfgType{
tun := tuntest.NewChannelTUN() IsSet: true,
binds := bindtest.NewChannelBinds() JunkPacketCount: 5,
level := LogLevelVerbose JunkPacketMinSize: 500,
dev := NewDevice( JunkPacketMaxSize: 1000,
tun.TUN(), InitPacketJunkSize: 30,
binds[0], ResponsePacketJunkSize: 40,
NewLogger(level, ""), InitPacketMagicHeader: 123456,
) ResponsePacketMagicHeader: 67543,
UnderloadPacketMagicHeader: 32345,
if err := dev.IpcSet(cfg[0]); err != nil { TransportPacketMagicHeader: 123123,
t.Errorf("failed to configure device %v", err) })
dev.Close()
return junkCreator{}, err
}
jc, err := NewJunkCreator(dev)
if err != nil { if err != nil {
t.Errorf("failed to create junk creator %v", err) t.Errorf("failed to create junk creator %v", err)
dev.Close()
return junkCreator{}, err return junkCreator{}, err
} }
@ -42,8 +33,9 @@ func Test_junkCreator_createJunkPackets(t *testing.T) {
if err != nil { if err != nil {
return return
} }
t.Run("", func(t *testing.T) { t.Run("valid", func(t *testing.T) {
got, err := jc.createJunkPackets() got := make([][]byte, jc.aSecCfg.JunkPacketCount)
err := jc.CreateJunkPackets(got)
if err != nil { if err != nil {
t.Errorf( t.Errorf(
"junkCreator.createJunkPackets() = %v; failed", "junkCreator.createJunkPackets() = %v; failed",
@ -68,7 +60,7 @@ func Test_junkCreator_createJunkPackets(t *testing.T) {
} }
func Test_junkCreator_randomJunkWithSize(t *testing.T) { func Test_junkCreator_randomJunkWithSize(t *testing.T) {
t.Run("", func(t *testing.T) { t.Run("valid", func(t *testing.T) {
jc, err := setUpJunkCreator(t) jc, err := setUpJunkCreator(t)
if err != nil { if err != nil {
return return
@ -78,7 +70,6 @@ func Test_junkCreator_randomJunkWithSize(t *testing.T) {
fmt.Printf("%v\n%v\n", r1, r2) fmt.Printf("%v\n%v\n", r1, r2)
if bytes.Equal(r1, r2) { if bytes.Equal(r1, r2) {
t.Errorf("same junks %v", err) t.Errorf("same junks %v", err)
jc.device.Close()
return return
} }
}) })
@ -90,14 +81,14 @@ func Test_junkCreator_randomPacketSize(t *testing.T) {
return return
} }
for range [30]struct{}{} { for range [30]struct{}{} {
t.Run("", func(t *testing.T) { t.Run("valid", func(t *testing.T) {
if got := jc.randomPacketSize(); jc.device.awg.aSecCfg.junkPacketMinSize > got || if got := jc.randomPacketSize(); jc.aSecCfg.JunkPacketMinSize > got ||
got > jc.device.awg.aSecCfg.junkPacketMaxSize { got > jc.aSecCfg.JunkPacketMaxSize {
t.Errorf( t.Errorf(
"junkCreator.randomPacketSize() = %v, not between range [%v,%v]", "junkCreator.randomPacketSize() = %v, not between range [%v,%v]",
got, got,
jc.device.awg.aSecCfg.junkPacketMinSize, jc.aSecCfg.JunkPacketMinSize,
jc.device.awg.aSecCfg.junkPacketMaxSize, jc.aSecCfg.JunkPacketMaxSize,
) )
} }
}) })
@ -109,13 +100,13 @@ func Test_junkCreator_appendJunk(t *testing.T) {
if err != nil { if err != nil {
return return
} }
t.Run("", func(t *testing.T) { t.Run("valid", func(t *testing.T) {
s := "apple" s := "apple"
buffer := bytes.NewBuffer([]byte(s)) buffer := bytes.NewBuffer([]byte(s))
err := jc.appendJunk(buffer, 30) err := jc.AppendJunk(buffer, 30)
if err != nil && if err != nil &&
buffer.Len() != len(s)+30 { buffer.Len() != len(s)+30 {
t.Errorf("appendWithJunk() size don't match") t.Error("appendWithJunk() size don't match")
} }
read := make([]byte, 50) read := make([]byte, 50)
buffer.Read(read) buffer.Read(read)

View file

@ -1,4 +1,4 @@
package junktag package awg
import ( import (
"errors" "errors"
@ -6,13 +6,14 @@ import (
) )
type SpecialHandshakeHandler struct { type SpecialHandshakeHandler struct {
SpecialJunk TaggedJunkGeneratorHandler SpecialJunk TagJunkGeneratorHandler
ControlledJunk TaggedJunkGeneratorHandler ControlledJunk TagJunkGeneratorHandler
nextItime time.Time nextItime time.Time
ITimeout time.Duration // seconds ITimeout time.Duration // seconds
// TODO: maybe atomic? // TODO: maybe atomic?
PacketCounter uint64 PacketCounter uint64
IsSet bool
} }
func (handler *SpecialHandshakeHandler) Validate() error { func (handler *SpecialHandshakeHandler) Validate() error {

View file

@ -1,4 +1,4 @@
package junktag package awg
import ( import (
crand "crypto/rand" crand "crypto/rand"

View file

@ -1,4 +1,4 @@
package junktag package awg
import ( import (
"fmt" "fmt"

View file

@ -0,0 +1,46 @@
package awg
import (
"fmt"
"strconv"
)
type TagJunkGenerator struct {
name string
packetSize int
generators []Generator
}
func newTagJunkGenerator(name string, size int) TagJunkGenerator {
return TagJunkGenerator{name: name, generators: make([]Generator, size)}
}
func (tg *TagJunkGenerator) append(generator Generator) {
tg.generators = append(tg.generators, generator)
tg.packetSize += generator.Size()
}
func (tg *TagJunkGenerator) generate() []byte {
packet := make([]byte, 0, tg.packetSize)
for _, generator := range tg.generators {
packet = append(packet, generator.Generate()...)
}
return packet
}
func (tg *TagJunkGenerator) Name() string {
return tg.name
}
func (tg *TagJunkGenerator) nameIndex() (int, error) {
if len(tg.name) != 2 {
return 0, fmt.Errorf("name must be 2 character long: %s", tg.name)
}
index, err := strconv.Atoi(tg.name[1:2])
if err != nil {
return 0, fmt.Errorf("name should be 2 char long: %w", err)
}
return index, nil
}

View file

@ -1,19 +1,21 @@
package junktag package awg
import "fmt" import "fmt"
type TaggedJunkGeneratorHandler struct { type TagJunkGeneratorHandler struct {
generators []TaggedJunkGenerator generators []TagJunkGenerator
length int length int
// Jc
DefaultJunkCount int
} }
func (handler *TaggedJunkGeneratorHandler) AppendGenerator(generators TaggedJunkGenerator) { func (handler *TagJunkGeneratorHandler) AppendGenerator(generators TagJunkGenerator) {
handler.generators = append(handler.generators, generators) handler.generators = append(handler.generators, generators)
handler.length++ handler.length++
} }
// validate that packets were defined consecutively // validate that packets were defined consecutively
func (handler *TaggedJunkGeneratorHandler) Validate() error { func (handler *TagJunkGeneratorHandler) Validate() error {
seen := make([]bool, len(handler.generators)) seen := make([]bool, len(handler.generators))
for _, generator := range handler.generators { for _, generator := range handler.generators {
if index, err := generator.nameIndex(); err != nil { if index, err := generator.nameIndex(); err != nil {
@ -32,8 +34,8 @@ func (handler *TaggedJunkGeneratorHandler) Validate() error {
return nil return nil
} }
func (handler *TaggedJunkGeneratorHandler) Generate() [][]byte { func (handler *TagJunkGeneratorHandler) Generate() [][]byte {
var rv = make([][]byte, handler.length) var rv = make([][]byte, handler.length+handler.DefaultJunkCount)
for i, generator := range handler.generators { for i, generator := range handler.generators {
rv[i] = make([]byte, generator.packetSize) rv[i] = make([]byte, generator.packetSize)
copy(rv[i], generator.generate()) copy(rv[i], generator.generate())

View file

@ -1,4 +1,4 @@
package junktag package awg
import ( import (
"fmt" "fmt"
@ -54,10 +54,10 @@ func parseTag(input string) (Tag, error) {
} }
// TODO: pointernes // TODO: pointernes
func Parse(name, input string) (TaggedJunkGenerator, error) { func Parse(name, input string) (TagJunkGenerator, error) {
inputSlice := strings.Split(input, "<") inputSlice := strings.Split(input, "<")
if len(inputSlice) <= 1 { if len(inputSlice) <= 1 {
return TaggedJunkGenerator{}, fmt.Errorf("empty input: %s", input) return TagJunkGenerator{}, fmt.Errorf("empty input: %s", input)
} }
uniqueTagCheck := make(map[EnumTag]bool, len(uniqueTags)) uniqueTagCheck := make(map[EnumTag]bool, len(uniqueTags))
@ -65,28 +65,28 @@ func Parse(name, input string) (TaggedJunkGenerator, error) {
// skip byproduct of split // skip byproduct of split
inputSlice = inputSlice[1:] inputSlice = inputSlice[1:]
rv := newTagedJunkGenerator(name, len(inputSlice)) rv := newTagJunkGenerator(name, len(inputSlice))
for _, inputParam := range inputSlice { for _, inputParam := range inputSlice {
if len(inputParam) <= 1 { if len(inputParam) <= 1 {
return TaggedJunkGenerator{}, fmt.Errorf("empty tag in input: %s", inputSlice) return TagJunkGenerator{}, fmt.Errorf("empty tag in input: %s", inputSlice)
} else if strings.Count(inputParam, ">") != 1 { } else if strings.Count(inputParam, ">") != 1 {
return TaggedJunkGenerator{}, fmt.Errorf("ill formated input: %s", input) return TagJunkGenerator{}, fmt.Errorf("ill formated input: %s", input)
} }
tag, _ := parseTag(inputParam) tag, _ := parseTag(inputParam)
creator, ok := generatorCreator[tag.Name] creator, ok := generatorCreator[tag.Name]
if !ok { if !ok {
return TaggedJunkGenerator{}, fmt.Errorf("invalid tag: %s", tag.Name) return TagJunkGenerator{}, fmt.Errorf("invalid tag: %s", tag.Name)
} }
if present, ok := uniqueTagCheck[tag.Name]; ok { if present, ok := uniqueTagCheck[tag.Name]; ok {
if present { if present {
return TaggedJunkGenerator{}, fmt.Errorf("tag %s needs to be unique", tag.Name) return TagJunkGenerator{}, fmt.Errorf("tag %s needs to be unique", tag.Name)
} }
uniqueTagCheck[tag.Name] = true uniqueTagCheck[tag.Name] = true
} }
generator, err := creator(tag.Param) generator, err := creator(tag.Param)
if err != nil { if err != nil {
return TaggedJunkGenerator{}, fmt.Errorf("gen: %w", err) return TagJunkGenerator{}, fmt.Errorf("gen: %w", err)
} }
rv.append(generator) rv.append(generator)

View file

@ -1,4 +1,4 @@
package junktag package awg
import ( import (
"fmt" "fmt"
@ -9,6 +9,7 @@ import (
func TestParse(t *testing.T) { func TestParse(t *testing.T) {
type args struct { type args struct {
name string
input string input string
} }
tests := []struct { tests := []struct {
@ -16,39 +17,44 @@ func TestParse(t *testing.T) {
args args args args
wantErr error wantErr error
}{ }{
{
name: "invalid name",
args: args{name: "apple", input: ""},
wantErr: fmt.Errorf("ill formated input"),
},
{ {
name: "empty", name: "empty",
args: args{input: ""}, args: args{name: "i1", input: ""},
wantErr: fmt.Errorf("ill formated input"), wantErr: fmt.Errorf("ill formated input"),
}, },
{ {
name: "extra >", name: "extra >",
args: args{input: "<b 0xf6ab3267fa><c>>"}, args: args{name: "i1", input: "<b 0xf6ab3267fa><c>>"},
wantErr: fmt.Errorf("ill formated input"), wantErr: fmt.Errorf("ill formated input"),
}, },
{ {
name: "extra <", name: "extra <",
args: args{input: "<<b 0xf6ab3267fa><c>"}, args: args{name: "i1", input: "<<b 0xf6ab3267fa><c>"},
wantErr: fmt.Errorf("empty tag in input"), wantErr: fmt.Errorf("empty tag in input"),
}, },
{ {
name: "empty <>", name: "empty <>",
args: args{input: "<><b 0xf6ab3267fa><c>"}, args: args{name: "i1", input: "<><b 0xf6ab3267fa><c>"},
wantErr: fmt.Errorf("empty tag in input"), wantErr: fmt.Errorf("empty tag in input"),
}, },
{ {
name: "invalid tag", name: "invalid tag",
args: args{input: "<q 0xf6ab3267fa>"}, args: args{name: "i1", input: "<q 0xf6ab3267fa>"},
wantErr: fmt.Errorf("invalid tag"), wantErr: fmt.Errorf("invalid tag"),
}, },
{ {
name: "counter uniqueness violation", name: "counter uniqueness violation",
args: args{input: "<c><c>"}, args: args{name: "i1", input: "<c><c>"},
wantErr: fmt.Errorf("parse tag needs to be unique"), wantErr: fmt.Errorf("parse tag needs to be unique"),
}, },
{ {
name: "timestamp uniqueness violation", name: "timestamp uniqueness violation",
args: args{input: "<t><t>"}, args: args{name: "i1", input: "<t><t>"},
wantErr: fmt.Errorf("parse tag needs to be unique"), wantErr: fmt.Errorf("parse tag needs to be unique"),
}, },
{ {
@ -58,7 +64,7 @@ func TestParse(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
_, err := Parse(tt.args.input) _, err := Parse(tt.args.name, tt.args.input)
// TODO: ErrorAs doesn't work as you think // TODO: ErrorAs doesn't work as you think
if tt.wantErr != nil { if tt.wantErr != nil {

View file

@ -6,18 +6,18 @@
package device package device
import ( import (
"errors"
"runtime" "runtime"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/amnezia-vpn/amneziawg-go/conn" "github.com/amnezia-vpn/amneziawg-go/conn"
junktag "github.com/amnezia-vpn/amneziawg-go/device/internal/junk-tag" "github.com/amnezia-vpn/amneziawg-go/device/awg"
"github.com/amnezia-vpn/amneziawg-go/ipc" "github.com/amnezia-vpn/amneziawg-go/ipc"
"github.com/amnezia-vpn/amneziawg-go/ratelimiter" "github.com/amnezia-vpn/amneziawg-go/ratelimiter"
"github.com/amnezia-vpn/amneziawg-go/rwcancel" "github.com/amnezia-vpn/amneziawg-go/rwcancel"
"github.com/amnezia-vpn/amneziawg-go/tun" "github.com/amnezia-vpn/amneziawg-go/tun"
"github.com/tevino/abool/v2"
) )
type Version uint8 type Version uint8
@ -129,31 +129,7 @@ type Device struct {
log *Logger log *Logger
version Version version Version
awg awg awg awg.Protocol
}
type awg struct {
isASecOn abool.AtomicBool
// TODO: revision the need of the mutex
aSecMux sync.RWMutex
aSecCfg aSecCfgType
junkCreator junkCreator
// TODO: determine if it's on
handshakeHandler junktag.SpecialHandshakeHandler
}
type aSecCfgType struct {
isSet bool
junkPacketCount int
junkPacketMinSize int
junkPacketMaxSize int
initPacketJunkSize int
responsePacketJunkSize int
initPacketMagicHeader uint32
responsePacketMagicHeader uint32
underloadPacketMagicHeader uint32
transportPacketMagicHeader uint32
} }
// deviceState represents the state of a Device. // deviceState represents the state of a Device.
@ -603,7 +579,7 @@ func (device *Device) BindClose() error {
return err return err
} }
func (device *Device) isAdvancedSecurityOn() bool { func (device *Device) isAdvancedSecurityOn() bool {
return device.awg.isASecOn.IsSet() return device.awg.IsASecOn.IsSet()
} }
func (device *Device) resetProtocol() { func (device *Device) resetProtocol() {
@ -614,165 +590,129 @@ func (device *Device) resetProtocol() {
MessageTransportType = DefaultMessageTransportType MessageTransportType = DefaultMessageTransportType
} }
func (device *Device) handlePostConfig(tempAwg *awg) (err error) { func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
if !tempAwg.ASecCfg.IsSet && !tempAwg.HandshakeHandler.IsSet {
if !tempAwg.aSecCfg.isSet { return nil
return err
} }
var errs []error
isASecOn := false isASecOn := false
device.awg.aSecMux.Lock() device.awg.ASecMux.Lock()
if tempAwg.aSecCfg.junkPacketCount < 0 { if tempAwg.ASecCfg.JunkPacketCount < 0 {
err = ipcErrorf( errs = append(errs, ipcErrorf(
ipc.IpcErrorInvalid, ipc.IpcErrorInvalid,
"JunkPacketCount should be non negative", "JunkPacketCount should be non negative",
),
) )
} }
device.awg.aSecCfg.junkPacketCount = tempAwg.aSecCfg.junkPacketCount device.awg.ASecCfg.JunkPacketCount = tempAwg.ASecCfg.JunkPacketCount
if tempAwg.aSecCfg.junkPacketCount != 0 { if tempAwg.ASecCfg.JunkPacketCount != 0 {
isASecOn = true isASecOn = true
} }
device.awg.aSecCfg.junkPacketMinSize = tempAwg.aSecCfg.junkPacketMinSize device.awg.ASecCfg.JunkPacketMinSize = tempAwg.ASecCfg.JunkPacketMinSize
if tempAwg.aSecCfg.junkPacketMinSize != 0 { if tempAwg.ASecCfg.JunkPacketMinSize != 0 {
isASecOn = true isASecOn = true
} }
if device.awg.aSecCfg.junkPacketCount > 0 && if device.awg.ASecCfg.JunkPacketCount > 0 &&
tempAwg.aSecCfg.junkPacketMaxSize == tempAwg.aSecCfg.junkPacketMinSize { tempAwg.ASecCfg.JunkPacketMaxSize == tempAwg.ASecCfg.JunkPacketMinSize {
tempAwg.aSecCfg.junkPacketMaxSize++ // to make rand gen work tempAwg.ASecCfg.JunkPacketMaxSize++ // to make rand gen work
} }
if tempAwg.aSecCfg.junkPacketMaxSize >= MaxSegmentSize { if tempAwg.ASecCfg.JunkPacketMaxSize >= MaxSegmentSize {
device.awg.aSecCfg.junkPacketMinSize = 0 device.awg.ASecCfg.JunkPacketMinSize = 0
device.awg.aSecCfg.junkPacketMaxSize = 1 device.awg.ASecCfg.JunkPacketMaxSize = 1
if err != nil { errs = append(errs, ipcErrorf(
err = ipcErrorf( ipc.IpcErrorInvalid,
ipc.IpcErrorInvalid, "JunkPacketMaxSize: %d; should be smaller than maxSegmentSize: %d",
"JunkPacketMaxSize: %d; should be smaller than maxSegmentSize: %d; %w", tempAwg.ASecCfg.JunkPacketMaxSize,
tempAwg.aSecCfg.junkPacketMaxSize, MaxSegmentSize,
MaxSegmentSize, ))
err, } else if tempAwg.ASecCfg.JunkPacketMaxSize < tempAwg.ASecCfg.JunkPacketMinSize {
) errs = append(errs, ipcErrorf(
} else { ipc.IpcErrorInvalid,
err = ipcErrorf( "maxSize: %d; should be greater than minSize: %d",
ipc.IpcErrorInvalid, tempAwg.ASecCfg.JunkPacketMaxSize,
"JunkPacketMaxSize: %d; should be smaller than maxSegmentSize: %d", tempAwg.ASecCfg.JunkPacketMinSize,
tempAwg.aSecCfg.junkPacketMaxSize, ))
MaxSegmentSize,
)
}
} else if tempAwg.aSecCfg.junkPacketMaxSize < tempAwg.aSecCfg.junkPacketMinSize {
if err != nil {
err = ipcErrorf(
ipc.IpcErrorInvalid,
"maxSize: %d; should be greater than minSize: %d; %w",
tempAwg.aSecCfg.junkPacketMaxSize,
tempAwg.aSecCfg.junkPacketMinSize,
err,
)
} else {
err = ipcErrorf(
ipc.IpcErrorInvalid,
"maxSize: %d; should be greater than minSize: %d",
tempAwg.aSecCfg.junkPacketMaxSize,
tempAwg.aSecCfg.junkPacketMinSize,
)
}
} else { } else {
device.awg.aSecCfg.junkPacketMaxSize = tempAwg.aSecCfg.junkPacketMaxSize device.awg.ASecCfg.JunkPacketMaxSize = tempAwg.ASecCfg.JunkPacketMaxSize
} }
if tempAwg.aSecCfg.junkPacketMaxSize != 0 { if tempAwg.ASecCfg.JunkPacketMaxSize != 0 {
isASecOn = true isASecOn = true
} }
if MessageInitiationSize+tempAwg.aSecCfg.initPacketJunkSize >= MaxSegmentSize { if MessageInitiationSize+tempAwg.ASecCfg.InitPacketJunkSize >= MaxSegmentSize {
if err != nil { errs = append(errs, ipcErrorf(
err = ipcErrorf( ipc.IpcErrorInvalid,
ipc.IpcErrorInvalid, `init header size(148) + junkSize:%d; should be smaller than maxSegmentSize: %d`,
`init header size(148) + junkSize:%d; should be smaller than maxSegmentSize: %d; %w`, tempAwg.ASecCfg.InitPacketJunkSize,
tempAwg.aSecCfg.initPacketJunkSize, MaxSegmentSize,
MaxSegmentSize, ),
err, )
)
} else {
err = ipcErrorf(
ipc.IpcErrorInvalid,
`init header size(148) + junkSize:%d; should be smaller than maxSegmentSize: %d`,
tempAwg.aSecCfg.initPacketJunkSize,
MaxSegmentSize,
)
}
} else { } else {
device.awg.aSecCfg.initPacketJunkSize = tempAwg.aSecCfg.initPacketJunkSize device.awg.ASecCfg.InitPacketJunkSize = tempAwg.ASecCfg.InitPacketJunkSize
} }
if tempAwg.aSecCfg.initPacketJunkSize != 0 { if tempAwg.ASecCfg.InitPacketJunkSize != 0 {
isASecOn = true isASecOn = true
} }
if MessageResponseSize+tempAwg.aSecCfg.responsePacketJunkSize >= MaxSegmentSize { if MessageResponseSize+tempAwg.ASecCfg.ResponsePacketJunkSize >= MaxSegmentSize {
if err != nil { errs = append(errs, ipcErrorf(
err = ipcErrorf( ipc.IpcErrorInvalid,
ipc.IpcErrorInvalid, `response header size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d`,
`response header size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d; %w`, tempAwg.ASecCfg.ResponsePacketJunkSize,
tempAwg.aSecCfg.responsePacketJunkSize, MaxSegmentSize,
MaxSegmentSize, ),
err, )
)
} else {
err = ipcErrorf(
ipc.IpcErrorInvalid,
`response header size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d`,
tempAwg.aSecCfg.responsePacketJunkSize,
MaxSegmentSize,
)
}
} else { } else {
device.awg.aSecCfg.responsePacketJunkSize = tempAwg.aSecCfg.responsePacketJunkSize device.awg.ASecCfg.ResponsePacketJunkSize = tempAwg.ASecCfg.ResponsePacketJunkSize
} }
if tempAwg.aSecCfg.responsePacketJunkSize != 0 { if tempAwg.ASecCfg.ResponsePacketJunkSize != 0 {
isASecOn = true isASecOn = true
} }
if tempAwg.aSecCfg.initPacketMagicHeader > 4 { if tempAwg.ASecCfg.InitPacketMagicHeader > 4 {
isASecOn = true isASecOn = true
device.log.Verbosef("UAPI: Updating init_packet_magic_header") device.log.Verbosef("UAPI: Updating init_packet_magic_header")
device.awg.aSecCfg.initPacketMagicHeader = tempAwg.aSecCfg.initPacketMagicHeader device.awg.ASecCfg.InitPacketMagicHeader = tempAwg.ASecCfg.InitPacketMagicHeader
MessageInitiationType = device.awg.aSecCfg.initPacketMagicHeader MessageInitiationType = device.awg.ASecCfg.InitPacketMagicHeader
} else { } else {
device.log.Verbosef("UAPI: Using default init type") device.log.Verbosef("UAPI: Using default init type")
MessageInitiationType = DefaultMessageInitiationType MessageInitiationType = DefaultMessageInitiationType
} }
if tempAwg.aSecCfg.responsePacketMagicHeader > 4 { if tempAwg.ASecCfg.ResponsePacketMagicHeader > 4 {
isASecOn = true isASecOn = true
device.log.Verbosef("UAPI: Updating response_packet_magic_header") device.log.Verbosef("UAPI: Updating response_packet_magic_header")
device.awg.aSecCfg.responsePacketMagicHeader = tempAwg.aSecCfg.responsePacketMagicHeader device.awg.ASecCfg.ResponsePacketMagicHeader = tempAwg.ASecCfg.ResponsePacketMagicHeader
MessageResponseType = device.awg.aSecCfg.responsePacketMagicHeader MessageResponseType = device.awg.ASecCfg.ResponsePacketMagicHeader
} else { } else {
device.log.Verbosef("UAPI: Using default response type") device.log.Verbosef("UAPI: Using default response type")
MessageResponseType = DefaultMessageResponseType MessageResponseType = DefaultMessageResponseType
} }
if tempAwg.aSecCfg.underloadPacketMagicHeader > 4 { if tempAwg.ASecCfg.UnderloadPacketMagicHeader > 4 {
isASecOn = true isASecOn = true
device.log.Verbosef("UAPI: Updating underload_packet_magic_header") device.log.Verbosef("UAPI: Updating underload_packet_magic_header")
device.awg.aSecCfg.underloadPacketMagicHeader = tempAwg.aSecCfg.underloadPacketMagicHeader device.awg.ASecCfg.UnderloadPacketMagicHeader = tempAwg.ASecCfg.UnderloadPacketMagicHeader
MessageCookieReplyType = device.awg.aSecCfg.underloadPacketMagicHeader MessageCookieReplyType = device.awg.ASecCfg.UnderloadPacketMagicHeader
} else { } else {
device.log.Verbosef("UAPI: Using default underload type") device.log.Verbosef("UAPI: Using default underload type")
MessageCookieReplyType = DefaultMessageCookieReplyType MessageCookieReplyType = DefaultMessageCookieReplyType
} }
if tempAwg.aSecCfg.transportPacketMagicHeader > 4 { if tempAwg.ASecCfg.TransportPacketMagicHeader > 4 {
isASecOn = true isASecOn = true
device.log.Verbosef("UAPI: Updating transport_packet_magic_header") device.log.Verbosef("UAPI: Updating transport_packet_magic_header")
device.awg.aSecCfg.transportPacketMagicHeader = tempAwg.aSecCfg.transportPacketMagicHeader device.awg.ASecCfg.TransportPacketMagicHeader = tempAwg.ASecCfg.TransportPacketMagicHeader
MessageTransportType = device.awg.aSecCfg.transportPacketMagicHeader MessageTransportType = device.awg.ASecCfg.TransportPacketMagicHeader
} else { } else {
device.log.Verbosef("UAPI: Using default transport type") device.log.Verbosef("UAPI: Using default transport type")
MessageTransportType = DefaultMessageTransportType MessageTransportType = DefaultMessageTransportType
@ -787,48 +727,28 @@ func (device *Device) handlePostConfig(tempAwg *awg) (err error) {
// size will be different if same values // size will be different if same values
if len(isSameMap) != 4 { if len(isSameMap) != 4 {
if err != nil { errs = append(errs, ipcErrorf(
err = ipcErrorf( ipc.IpcErrorInvalid,
ipc.IpcErrorInvalid, `magic headers should differ; got: init:%d; recv:%d; unde:%d; tran:%d`,
`magic headers should differ; got: init:%d; recv:%d; unde:%d; tran:%d; %w`, MessageInitiationType,
MessageInitiationType, MessageResponseType,
MessageResponseType, MessageCookieReplyType,
MessageCookieReplyType, MessageTransportType,
MessageTransportType, ),
err, )
)
} else {
err = ipcErrorf(
ipc.IpcErrorInvalid,
`magic headers should differ; got: init:%d; recv:%d; unde:%d; tran:%d`,
MessageInitiationType,
MessageResponseType,
MessageCookieReplyType,
MessageTransportType,
)
}
} }
newInitSize := MessageInitiationSize + device.awg.aSecCfg.initPacketJunkSize newInitSize := MessageInitiationSize + device.awg.ASecCfg.InitPacketJunkSize
newResponseSize := MessageResponseSize + device.awg.aSecCfg.responsePacketJunkSize newResponseSize := MessageResponseSize + device.awg.ASecCfg.ResponsePacketJunkSize
if newInitSize == newResponseSize { if newInitSize == newResponseSize {
if err != nil { errs = append(errs, ipcErrorf(
err = ipcErrorf( ipc.IpcErrorInvalid,
ipc.IpcErrorInvalid, `new init size:%d; and new response size:%d; should differ`,
`new init size:%d; and new response size:%d; should differ; %w`, newInitSize,
newInitSize, newResponseSize,
newResponseSize, ),
err, )
)
} else {
err = ipcErrorf(
ipc.IpcErrorInvalid,
`new init size:%d; and new response size:%d; should differ`,
newInitSize,
newResponseSize,
)
}
} else { } else {
packetSizeToMsgType = map[int]uint32{ packetSizeToMsgType = map[int]uint32{
newInitSize: MessageInitiationType, newInitSize: MessageInitiationType,
@ -838,23 +758,35 @@ func (device *Device) handlePostConfig(tempAwg *awg) (err error) {
} }
msgTypeToJunkSize = map[uint32]int{ msgTypeToJunkSize = map[uint32]int{
MessageInitiationType: device.awg.aSecCfg.initPacketJunkSize, MessageInitiationType: device.awg.ASecCfg.InitPacketJunkSize,
MessageResponseType: device.awg.aSecCfg.responsePacketJunkSize, MessageResponseType: device.awg.ASecCfg.ResponsePacketJunkSize,
MessageCookieReplyType: 0, MessageCookieReplyType: 0,
MessageTransportType: 0, MessageTransportType: 0,
} }
} }
if err := tempAwg.handshakeHandler.Validate(); err == nil { device.awg.IsASecOn.SetTo(isASecOn)
return ipcErrorf(ipc.IpcErrorInvalid, "handle post config foo validate: %w", err) var err error
device.awg.JunkCreator, err = awg.NewJunkCreator(device.awg.ASecCfg)
if err != nil {
errs = append(errs, err)
} }
device.awg.isASecOn.SetTo(isASecOn)
device.awg.junkCreator, err = NewJunkCreator(device)
device.awg.handshakeHandler = tempAwg.handshakeHandler
// TODO:
device.version = VersionAwgSpecialHandshake
device.awg.aSecMux.Unlock() if tempAwg.HandshakeHandler.IsSet {
if err := tempAwg.HandshakeHandler.Validate(); tempAwg.HandshakeHandler.IsSet && err != nil {
errs = append(errs, ipcErrorf(
ipc.IpcErrorInvalid, "handshake handler validate: %w", err))
} else {
device.awg.HandshakeHandler = tempAwg.HandshakeHandler
device.awg.HandshakeHandler.ControlledJunk.DefaultJunkCount = tempAwg.ASecCfg.JunkPacketCount
device.awg.HandshakeHandler.SpecialJunk.DefaultJunkCount = tempAwg.ASecCfg.JunkPacketCount
device.version = VersionAwgSpecialHandshake
}
} else {
device.version = VersionAwg
}
return err device.awg.ASecMux.Unlock()
return errors.Join(errs...)
} }

View file

@ -1,42 +0,0 @@
package junktag
import (
"fmt"
"strconv"
)
type TaggedJunkGenerator struct {
name string
packetSize int
generators []Generator
}
func newTagedJunkGenerator(name string, size int) TaggedJunkGenerator {
return TaggedJunkGenerator{name: name, generators: make([]Generator, size)}
}
func (tg *TaggedJunkGenerator) append(generator Generator) {
tg.generators = append(tg.generators, generator)
tg.packetSize += generator.Size()
}
func (tg *TaggedJunkGenerator) generate() []byte {
packet := make([]byte, 0, tg.packetSize)
for _, generator := range tg.generators {
packet = append(packet, generator.Generate()...)
}
return packet
}
func (t *TaggedJunkGenerator) nameIndex() (int, error) {
if len(t.name) != 2 {
return 0, fmt.Errorf("name must be 2 character long: %s", t.name)
}
index, err := strconv.Atoi(t.name[1:2])
if err != nil {
return 0, fmt.Errorf("name should be 2 char long: %w", err)
}
return index, nil
}

View file

@ -204,12 +204,12 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
handshake.mixHash(handshake.remoteStatic[:]) handshake.mixHash(handshake.remoteStatic[:])
device.awg.aSecMux.RLock() device.awg.ASecMux.RLock()
msg := MessageInitiation{ msg := MessageInitiation{
Type: MessageInitiationType, Type: MessageInitiationType,
Ephemeral: handshake.localEphemeral.publicKey(), Ephemeral: handshake.localEphemeral.publicKey(),
} }
device.awg.aSecMux.RUnlock() device.awg.ASecMux.RUnlock()
handshake.mixKey(msg.Ephemeral[:]) handshake.mixKey(msg.Ephemeral[:])
handshake.mixHash(msg.Ephemeral[:]) handshake.mixHash(msg.Ephemeral[:])
@ -263,12 +263,12 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
chainKey [blake2s.Size]byte chainKey [blake2s.Size]byte
) )
device.awg.aSecMux.RLock() device.awg.ASecMux.RLock()
if msg.Type != MessageInitiationType { if msg.Type != MessageInitiationType {
device.awg.aSecMux.RUnlock() device.awg.ASecMux.RUnlock()
return nil return nil
} }
device.awg.aSecMux.RUnlock() device.awg.ASecMux.RUnlock()
device.staticIdentity.RLock() device.staticIdentity.RLock()
defer device.staticIdentity.RUnlock() defer device.staticIdentity.RUnlock()
@ -383,9 +383,9 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
} }
var msg MessageResponse var msg MessageResponse
device.awg.aSecMux.RLock() device.awg.ASecMux.RLock()
msg.Type = MessageResponseType msg.Type = MessageResponseType
device.awg.aSecMux.RUnlock() device.awg.ASecMux.RUnlock()
msg.Sender = handshake.localIndex msg.Sender = handshake.localIndex
msg.Receiver = handshake.remoteIndex msg.Receiver = handshake.remoteIndex
@ -435,12 +435,12 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
} }
func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
device.awg.aSecMux.RLock() device.awg.ASecMux.RLock()
if msg.Type != MessageResponseType { if msg.Type != MessageResponseType {
device.awg.aSecMux.RUnlock() device.awg.ASecMux.RUnlock()
return nil return nil
} }
device.awg.aSecMux.RUnlock() device.awg.ASecMux.RUnlock()
// lookup handshake by receiver // lookup handshake by receiver

View file

@ -137,7 +137,7 @@ func (peer *Peer) SendBuffers(buffers [][]byte) error {
if err == nil { if err == nil {
var totalLen uint64 var totalLen uint64
for _, b := range buffers { for _, b := range buffers {
peer.device.awg.foo.PacketCounter++ peer.device.awg.HandshakeHandler.PacketCounter++
totalLen += uint64(len(b)) totalLen += uint64(len(b))
} }
peer.txBytes.Add(totalLen) peer.txBytes.Add(totalLen)

View file

@ -129,7 +129,7 @@ func (device *Device) RoutineReceiveIncoming(
} }
deathSpiral = 0 deathSpiral = 0
device.awg.aSecMux.RLock() device.awg.ASecMux.RLock()
// handle each packet in the batch // handle each packet in the batch
for i, size := range sizes[:count] { for i, size := range sizes[:count] {
if size < MinMessageSize { if size < MinMessageSize {
@ -246,7 +246,7 @@ func (device *Device) RoutineReceiveIncoming(
default: default:
} }
} }
device.awg.aSecMux.RUnlock() device.awg.ASecMux.RUnlock()
for peer, elemsContainer := range elemsByPeer { for peer, elemsContainer := range elemsByPeer {
if peer.isRunning.Load() { if peer.isRunning.Load() {
peer.queue.inbound.c <- elemsContainer peer.queue.inbound.c <- elemsContainer
@ -305,7 +305,7 @@ func (device *Device) RoutineHandshake(id int) {
for elem := range device.queue.handshake.c { for elem := range device.queue.handshake.c {
device.awg.aSecMux.RLock() device.awg.ASecMux.RLock()
// handle cookie fields and ratelimiting // handle cookie fields and ratelimiting
@ -457,7 +457,7 @@ func (device *Device) RoutineHandshake(id int) {
peer.SendKeepalive() peer.SendKeepalive()
} }
skip: skip:
device.awg.aSecMux.RUnlock() device.awg.ASecMux.RUnlock()
device.PutMessageBuffer(elem.buffer) device.PutMessageBuffer(elem.buffer)
} }
} }

View file

@ -128,19 +128,21 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
var junkedHeader []byte var junkedHeader []byte
if peer.device.version >= VersionAwg { if peer.device.version >= VersionAwg {
junks := [][]byte{} var junks [][]byte
if peer.device.version == VersionAwgSpecialHandshake { if peer.device.version == VersionAwgSpecialHandshake {
peer.device.awg.aSecMux.RLock() peer.device.awg.ASecMux.RLock()
// set junks depending on packet type // set junks depending on packet type
junks = peer.device.awg.handshakeHandler.GenerateSpecialJunk() junks = peer.device.awg.HandshakeHandler.GenerateSpecialJunk()
if junks == nil { if junks == nil {
junks = peer.device.awg.handshakeHandler.GenerateSpecialJunk() junks = peer.device.awg.HandshakeHandler.GenerateSpecialJunk()
} }
peer.device.awg.aSecMux.RUnlock() peer.device.awg.ASecMux.RUnlock()
} else {
junks = make([][]byte, peer.device.awg.ASecCfg.JunkPacketCount)
} }
peer.device.awg.aSecMux.RLock() peer.device.awg.ASecMux.RLock()
err := peer.device.awg.junkCreator.createJunkPackets(&junks) err := peer.device.awg.JunkCreator.CreateJunkPackets(junks)
peer.device.awg.aSecMux.RUnlock() peer.device.awg.ASecMux.RUnlock()
if err != nil { if err != nil {
peer.device.log.Errorf("%v - %v", peer, err) peer.device.log.Errorf("%v - %v", peer, err)
@ -156,19 +158,19 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
} }
} }
peer.device.awg.aSecMux.RLock() peer.device.awg.ASecMux.RLock()
if peer.device.awg.aSecCfg.initPacketJunkSize != 0 { if peer.device.awg.ASecCfg.InitPacketJunkSize != 0 {
buf := make([]byte, 0, peer.device.awg.aSecCfg.initPacketJunkSize) buf := make([]byte, 0, peer.device.awg.ASecCfg.InitPacketJunkSize)
writer := bytes.NewBuffer(buf[:0]) writer := bytes.NewBuffer(buf[:0])
err = peer.device.awg.junkCreator.appendJunk(writer, peer.device.awg.aSecCfg.initPacketJunkSize) err = peer.device.awg.JunkCreator.AppendJunk(writer, peer.device.awg.ASecCfg.InitPacketJunkSize)
if err != nil { if err != nil {
peer.device.log.Errorf("%v - %v", peer, err) peer.device.log.Errorf("%v - %v", peer, err)
peer.device.awg.aSecMux.RUnlock() peer.device.awg.ASecMux.RUnlock()
return err return err
} }
junkedHeader = writer.Bytes() junkedHeader = writer.Bytes()
} }
peer.device.awg.aSecMux.RUnlock() peer.device.awg.ASecMux.RUnlock()
} }
var buf [MessageInitiationSize]byte var buf [MessageInitiationSize]byte
@ -206,19 +208,19 @@ func (peer *Peer) SendHandshakeResponse() error {
} }
var junkedHeader []byte var junkedHeader []byte
if peer.device.isAdvancedSecurityOn() { if peer.device.isAdvancedSecurityOn() {
peer.device.awg.aSecMux.RLock() peer.device.awg.ASecMux.RLock()
if peer.device.awg.aSecCfg.responsePacketJunkSize != 0 { if peer.device.awg.ASecCfg.ResponsePacketJunkSize != 0 {
buf := make([]byte, 0, peer.device.awg.aSecCfg.responsePacketJunkSize) buf := make([]byte, 0, peer.device.awg.ASecCfg.ResponsePacketJunkSize)
writer := bytes.NewBuffer(buf[:0]) writer := bytes.NewBuffer(buf[:0])
err = peer.device.awg.junkCreator.appendJunk(writer, peer.device.awg.aSecCfg.responsePacketJunkSize) err = peer.device.awg.JunkCreator.AppendJunk(writer, peer.device.awg.ASecCfg.ResponsePacketJunkSize)
if err != nil { if err != nil {
peer.device.awg.aSecMux.RUnlock() peer.device.awg.ASecMux.RUnlock()
peer.device.log.Errorf("%v - %v", peer, err) peer.device.log.Errorf("%v - %v", peer, err)
return err return err
} }
junkedHeader = writer.Bytes() junkedHeader = writer.Bytes()
} }
peer.device.awg.aSecMux.RUnlock() peer.device.awg.ASecMux.RUnlock()
} }
var buf [MessageResponseSize]byte var buf [MessageResponseSize]byte
writer := bytes.NewBuffer(buf[:0]) writer := bytes.NewBuffer(buf[:0])

View file

@ -18,7 +18,7 @@ import (
"sync" "sync"
"time" "time"
junktag "github.com/amnezia-vpn/amneziawg-go/device/internal/junk-tag" "github.com/amnezia-vpn/amneziawg-go/device/awg"
"github.com/amnezia-vpn/amneziawg-go/ipc" "github.com/amnezia-vpn/amneziawg-go/ipc"
) )
@ -99,33 +99,37 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
} }
if device.isAdvancedSecurityOn() { if device.isAdvancedSecurityOn() {
if device.awg.aSecCfg.junkPacketCount != 0 { if device.awg.ASecCfg.JunkPacketCount != 0 {
sendf("jc=%d", device.awg.aSecCfg.junkPacketCount) sendf("jc=%d", device.awg.ASecCfg.JunkPacketCount)
} }
if device.awg.aSecCfg.junkPacketMinSize != 0 { if device.awg.ASecCfg.JunkPacketMinSize != 0 {
sendf("jmin=%d", device.awg.aSecCfg.junkPacketMinSize) sendf("jmin=%d", device.awg.ASecCfg.JunkPacketMinSize)
} }
if device.awg.aSecCfg.junkPacketMaxSize != 0 { if device.awg.ASecCfg.JunkPacketMaxSize != 0 {
sendf("jmax=%d", device.awg.aSecCfg.junkPacketMaxSize) sendf("jmax=%d", device.awg.ASecCfg.JunkPacketMaxSize)
} }
if device.awg.aSecCfg.initPacketJunkSize != 0 { if device.awg.ASecCfg.InitPacketJunkSize != 0 {
sendf("s1=%d", device.awg.aSecCfg.initPacketJunkSize) sendf("s1=%d", device.awg.ASecCfg.InitPacketJunkSize)
} }
if device.awg.aSecCfg.responsePacketJunkSize != 0 { if device.awg.ASecCfg.ResponsePacketJunkSize != 0 {
sendf("s2=%d", device.awg.aSecCfg.responsePacketJunkSize) sendf("s2=%d", device.awg.ASecCfg.ResponsePacketJunkSize)
} }
if device.awg.aSecCfg.initPacketMagicHeader != 0 { if device.awg.ASecCfg.InitPacketMagicHeader != 0 {
sendf("h1=%d", device.awg.aSecCfg.initPacketMagicHeader) sendf("h1=%d", device.awg.ASecCfg.InitPacketMagicHeader)
} }
if device.awg.aSecCfg.responsePacketMagicHeader != 0 { if device.awg.ASecCfg.ResponsePacketMagicHeader != 0 {
sendf("h2=%d", device.awg.aSecCfg.responsePacketMagicHeader) sendf("h2=%d", device.awg.ASecCfg.ResponsePacketMagicHeader)
} }
if device.awg.aSecCfg.underloadPacketMagicHeader != 0 { if device.awg.ASecCfg.UnderloadPacketMagicHeader != 0 {
sendf("h3=%d", device.awg.aSecCfg.underloadPacketMagicHeader) sendf("h3=%d", device.awg.ASecCfg.UnderloadPacketMagicHeader)
} }
if device.awg.aSecCfg.transportPacketMagicHeader != 0 { if device.awg.ASecCfg.TransportPacketMagicHeader != 0 {
sendf("h4=%d", device.awg.aSecCfg.transportPacketMagicHeader) sendf("h4=%d", device.awg.ASecCfg.TransportPacketMagicHeader)
} }
// for _, generator := range device.awg.HandshakeHandler.ControlledJunk.AppendGenerator {
// sendf("h4=%d", device.awg.ASecCfg.TransportPacketMagicHeader)
// }
} }
for _, peer := range device.peers.keyMap { for _, peer := range device.peers.keyMap {
@ -181,7 +185,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
peer := new(ipcSetPeer) peer := new(ipcSetPeer)
deviceConfig := true deviceConfig := true
tempAwg := awg{} tempAwg := awg.Protocol{}
scanner := bufio.NewScanner(r) scanner := bufio.NewScanner(r)
for scanner.Scan() { for scanner.Scan() {
line := scanner.Text() line := scanner.Text()
@ -238,7 +242,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
return nil return nil
} }
func (device *Device) handleDeviceLine(key, value string, tempAwg *awg) error { func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol) error {
switch key { switch key {
case "private_key": case "private_key":
var sk NoisePrivateKey var sk NoisePrivateKey
@ -290,8 +294,8 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg) error {
return ipcErrorf(ipc.IpcErrorInvalid, "parse junk_packet_count %w", err) return ipcErrorf(ipc.IpcErrorInvalid, "parse junk_packet_count %w", err)
} }
device.log.Verbosef("UAPI: Updating junk_packet_count") device.log.Verbosef("UAPI: Updating junk_packet_count")
tempAwg.aSecCfg.junkPacketCount = junkPacketCount tempAwg.ASecCfg.JunkPacketCount = junkPacketCount
tempAwg.aSecCfg.isSet = true tempAwg.ASecCfg.IsSet = true
case "jmin": case "jmin":
junkPacketMinSize, err := strconv.Atoi(value) junkPacketMinSize, err := strconv.Atoi(value)
@ -299,8 +303,8 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg) error {
return ipcErrorf(ipc.IpcErrorInvalid, "parse junk_packet_min_size %w", err) return ipcErrorf(ipc.IpcErrorInvalid, "parse junk_packet_min_size %w", err)
} }
device.log.Verbosef("UAPI: Updating junk_packet_min_size") device.log.Verbosef("UAPI: Updating junk_packet_min_size")
tempAwg.aSecCfg.junkPacketMinSize = junkPacketMinSize tempAwg.ASecCfg.JunkPacketMinSize = junkPacketMinSize
tempAwg.aSecCfg.isSet = true tempAwg.ASecCfg.IsSet = true
case "jmax": case "jmax":
junkPacketMaxSize, err := strconv.Atoi(value) junkPacketMaxSize, err := strconv.Atoi(value)
@ -308,8 +312,8 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg) error {
return ipcErrorf(ipc.IpcErrorInvalid, "parse junk_packet_max_size %w", err) return ipcErrorf(ipc.IpcErrorInvalid, "parse junk_packet_max_size %w", err)
} }
device.log.Verbosef("UAPI: Updating junk_packet_max_size") device.log.Verbosef("UAPI: Updating junk_packet_max_size")
tempAwg.aSecCfg.junkPacketMaxSize = junkPacketMaxSize tempAwg.ASecCfg.JunkPacketMaxSize = junkPacketMaxSize
tempAwg.aSecCfg.isSet = true tempAwg.ASecCfg.IsSet = true
case "s1": case "s1":
initPacketJunkSize, err := strconv.Atoi(value) initPacketJunkSize, err := strconv.Atoi(value)
@ -317,8 +321,8 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg) error {
return ipcErrorf(ipc.IpcErrorInvalid, "parse init_packet_junk_size %w", err) return ipcErrorf(ipc.IpcErrorInvalid, "parse init_packet_junk_size %w", err)
} }
device.log.Verbosef("UAPI: Updating init_packet_junk_size") device.log.Verbosef("UAPI: Updating init_packet_junk_size")
tempAwg.aSecCfg.initPacketJunkSize = initPacketJunkSize tempAwg.ASecCfg.InitPacketJunkSize = initPacketJunkSize
tempAwg.aSecCfg.isSet = true tempAwg.ASecCfg.IsSet = true
case "s2": case "s2":
responsePacketJunkSize, err := strconv.Atoi(value) responsePacketJunkSize, err := strconv.Atoi(value)
@ -326,65 +330,65 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg) error {
return ipcErrorf(ipc.IpcErrorInvalid, "parse response_packet_junk_size %w", err) return ipcErrorf(ipc.IpcErrorInvalid, "parse response_packet_junk_size %w", err)
} }
device.log.Verbosef("UAPI: Updating response_packet_junk_size") device.log.Verbosef("UAPI: Updating response_packet_junk_size")
tempAwg.aSecCfg.responsePacketJunkSize = responsePacketJunkSize tempAwg.ASecCfg.ResponsePacketJunkSize = responsePacketJunkSize
tempAwg.aSecCfg.isSet = true tempAwg.ASecCfg.IsSet = true
case "h1": case "h1":
initPacketMagicHeader, err := strconv.ParseUint(value, 10, 32) initPacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
if err != nil { if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "parse init_packet_magic_header %w", err) return ipcErrorf(ipc.IpcErrorInvalid, "parse init_packet_magic_header %w", err)
} }
tempAwg.aSecCfg.initPacketMagicHeader = uint32(initPacketMagicHeader) tempAwg.ASecCfg.InitPacketMagicHeader = uint32(initPacketMagicHeader)
tempAwg.aSecCfg.isSet = true tempAwg.ASecCfg.IsSet = true
case "h2": case "h2":
responsePacketMagicHeader, err := strconv.ParseUint(value, 10, 32) responsePacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
if err != nil { if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "parse response_packet_magic_header %w", err) return ipcErrorf(ipc.IpcErrorInvalid, "parse response_packet_magic_header %w", err)
} }
tempAwg.aSecCfg.responsePacketMagicHeader = uint32(responsePacketMagicHeader) tempAwg.ASecCfg.ResponsePacketMagicHeader = uint32(responsePacketMagicHeader)
tempAwg.aSecCfg.isSet = true tempAwg.ASecCfg.IsSet = true
case "h3": case "h3":
underloadPacketMagicHeader, err := strconv.ParseUint(value, 10, 32) underloadPacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
if err != nil { if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "parse underload_packet_magic_header %w", err) return ipcErrorf(ipc.IpcErrorInvalid, "parse underload_packet_magic_header %w", err)
} }
tempAwg.aSecCfg.underloadPacketMagicHeader = uint32(underloadPacketMagicHeader) tempAwg.ASecCfg.UnderloadPacketMagicHeader = uint32(underloadPacketMagicHeader)
tempAwg.aSecCfg.isSet = true tempAwg.ASecCfg.IsSet = true
case "h4": case "h4":
transportPacketMagicHeader, err := strconv.ParseUint(value, 10, 32) transportPacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
if err != nil { if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "parse transport_packet_magic_header %w", err) return ipcErrorf(ipc.IpcErrorInvalid, "parse transport_packet_magic_header %w", err)
} }
tempAwg.aSecCfg.transportPacketMagicHeader = uint32(transportPacketMagicHeader) tempAwg.ASecCfg.TransportPacketMagicHeader = uint32(transportPacketMagicHeader)
tempAwg.aSecCfg.isSet = true tempAwg.ASecCfg.IsSet = true
case "i1", "i2", "i3", "i4", "i5": case "i1", "i2", "i3", "i4", "i5":
if len(value) == 0 { if len(value) == 0 {
return ipcErrorf(ipc.IpcErrorInvalid, "%s should be non null", key) return ipcErrorf(ipc.IpcErrorInvalid, "%s should be non null", key)
} }
generators, err := junktag.Parse(key, value) generators, err := awg.Parse(key, value)
if err != nil { if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "invalid %s: %w", key, err) return ipcErrorf(ipc.IpcErrorInvalid, "invalid %s: %w", key, err)
} }
device.log.Verbosef("UAPI: Updating %s", key) device.log.Verbosef("UAPI: Updating %s", key)
tempAwg.handshakeHandler.SpecialJunk.AppendGenerator(generators) tempAwg.HandshakeHandler.SpecialJunk.AppendGenerator(generators)
tempAwg.aSecCfg.isSet = true tempAwg.HandshakeHandler.IsSet = true
case "j1", "j2", "j3": case "j1", "j2", "j3":
if len(value) == 0 { if len(value) == 0 {
return ipcErrorf(ipc.IpcErrorInvalid, "%s should be non null", key) return ipcErrorf(ipc.IpcErrorInvalid, "%s should be non null", key)
} }
generators, err := junktag.Parse(key, value) generators, err := awg.Parse(key, value)
if err != nil { if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "invalid %s: %w", key, err) return ipcErrorf(ipc.IpcErrorInvalid, "invalid %s: %w", key, err)
} }
device.log.Verbosef("UAPI: Updating %s", key) device.log.Verbosef("UAPI: Updating %s", key)
tempAwg.handshakeHandler.ControlledJunk.AppendGenerator(generators) tempAwg.HandshakeHandler.ControlledJunk.AppendGenerator(generators)
tempAwg.aSecCfg.isSet = true tempAwg.HandshakeHandler.IsSet = true
case "itime": case "itime":
itime, err := strconv.ParseInt(value, 10, 64) itime, err := strconv.ParseInt(value, 10, 64)
if err != nil { if err != nil {
@ -392,8 +396,8 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg) error {
} }
device.log.Verbosef("UAPI: Updating itime %s", itime) device.log.Verbosef("UAPI: Updating itime %s", itime)
tempAwg.handshakeHandler.ITimeout = time.Duration(itime) tempAwg.HandshakeHandler.ITimeout = time.Duration(itime)
tempAwg.aSecCfg.isSet = true tempAwg.HandshakeHandler.IsSet = true
default: default:
return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI device key: %v", key) return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI device key: %v", key)
} }

4
go.mod
View file

@ -4,12 +4,12 @@ go 1.24
require ( require (
github.com/stretchr/testify v1.10.0 github.com/stretchr/testify v1.10.0
github.com/tevino/abool/v2 v2.1.0 github.com/tevino/abool v1.2.0
golang.org/x/crypto v0.36.0 golang.org/x/crypto v0.36.0
golang.org/x/net v0.37.0 golang.org/x/net v0.37.0
golang.org/x/sys v0.31.0 golang.org/x/sys v0.31.0
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
gvisor.dev/gvisor v0.0.0-20250130013005-04f9204697c6 gvisor.dev/gvisor v0.0.0-20250606233247-e3c4c4cad86f
) )
require ( require (

8
go.sum
View file

@ -8,8 +8,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/tevino/abool/v2 v2.1.0 h1:7w+Vf9f/5gmKT4m4qkayb33/92M+Um45F2BkHOR+L/c= github.com/tevino/abool v1.2.0 h1:heAkClL8H6w+mK5md9dzsuohKeXHUpY7Vw0ZCKW+huA=
github.com/tevino/abool/v2 v2.1.0/go.mod h1:+Lmlqk6bHDWHqN1cbxqhwEAwMPXgc8I1SDEamtseuXY= github.com/tevino/abool v1.2.0/go.mod h1:qc66Pna1RiIsPa7O4Egxxs9OqkuxDX55zznh9K07Tzg=
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34= golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc= golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
golang.org/x/mod v0.21.0 h1:vvrHzRwRfVKSiLrG+d4FMl/Qi4ukBCE6kZlTUkDYRT0= golang.org/x/mod v0.21.0 h1:vvrHzRwRfVKSiLrG+d4FMl/Qi4ukBCE6kZlTUkDYRT0=
@ -26,5 +26,5 @@ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gvisor.dev/gvisor v0.0.0-20250130013005-04f9204697c6 h1:6B7MdW3OEbJqOMr7cEYU9bkzvCjUBX/JlXk12xcANuQ= gvisor.dev/gvisor v0.0.0-20250606233247-e3c4c4cad86f h1:zmc4cHEcCudRt2O8VsCW7nYLfAsbVY2i910/DAop1TM=
gvisor.dev/gvisor v0.0.0-20250130013005-04f9204697c6/go.mod h1:5DMfjtclAbTIjbXqO1qCe2K5GKKxWz2JHvCChuTcJEM= gvisor.dev/gvisor v0.0.0-20250606233247-e3c4c4cad86f/go.mod h1:3r5CMtNQMKIvBlrmM9xWUNamjKBYPOWyXOjmg5Kts3g=