From 65743536a2a1cfaec6b8b391b82ae5764b5a4a01 Mon Sep 17 00:00:00 2001 From: Mark Puha Date: Mon, 9 Jun 2025 16:41:54 +0200 Subject: [PATCH] chore: project restructure --- device/awg/awg.go | 30 ++ device/{ => awg}/junk_creator.go | 23 +- device/{ => awg}/junk_creator_test.go | 59 ++-- .../special_handshake_handler.go | 7 +- .../generator.go => awg/tag_generator.go} | 2 +- .../tag_generator_test.go} | 2 +- device/awg/tag_junk_generator.go | 46 +++ .../tag_junk_generator_handler.go} | 16 +- .../junk-tag/parser.go => awg/tag_parser.go} | 18 +- .../parser_test.go => awg/tag_parser_test.go} | 24 +- device/device.go | 290 +++++++----------- .../junk-tag/tagged_junk_generator.go | 42 --- device/noise-protocol.go | 20 +- device/peer.go | 2 +- device/receive.go | 8 +- device/send.go | 42 +-- device/uapi.go | 98 +++--- go.mod | 4 +- go.sum | 8 +- 19 files changed, 356 insertions(+), 385 deletions(-) create mode 100644 device/awg/awg.go rename device/{ => awg}/junk_creator.go (64%) rename device/{ => awg}/junk_creator_test.go (60%) rename device/{internal/junk-tag => awg}/special_handshake_handler.go (90%) rename device/{internal/junk-tag/generator.go => awg/tag_generator.go} (99%) rename device/{internal/junk-tag/generator_test.go => awg/tag_generator_test.go} (99%) create mode 100644 device/awg/tag_junk_generator.go rename device/{internal/junk-tag/tagged_junk_generator_handler.go => awg/tag_junk_generator_handler.go} (63%) rename device/{internal/junk-tag/parser.go => awg/tag_parser.go} (76%) rename device/{internal/junk-tag/parser_test.go => awg/tag_parser_test.go} (66%) delete mode 100644 device/internal/junk-tag/tagged_junk_generator.go diff --git a/device/awg/awg.go b/device/awg/awg.go new file mode 100644 index 0000000..a2da6f1 --- /dev/null +++ b/device/awg/awg.go @@ -0,0 +1,30 @@ +package awg + +import ( + "sync" + + "github.com/tevino/abool" +) + +type Protocol struct { + IsASecOn abool.AtomicBool + // TODO: revision the need of the mutex + ASecMux sync.RWMutex + ASecCfg aSecCfgType + JunkCreator junkCreator + + HandshakeHandler SpecialHandshakeHandler +} + +type aSecCfgType struct { + IsSet bool + JunkPacketCount int + JunkPacketMinSize int + JunkPacketMaxSize int + InitPacketJunkSize int + ResponsePacketJunkSize int + InitPacketMagicHeader uint32 + ResponsePacketMagicHeader uint32 + UnderloadPacketMagicHeader uint32 + TransportPacketMagicHeader uint32 +} diff --git a/device/junk_creator.go b/device/awg/junk_creator.go similarity index 64% rename from device/junk_creator.go rename to device/awg/junk_creator.go index 1df4612..d67eb00 100644 --- a/device/junk_creator.go +++ b/device/awg/junk_creator.go @@ -1,4 +1,4 @@ -package device +package awg import ( "bytes" @@ -8,33 +8,32 @@ import ( ) type junkCreator struct { - device *Device + aSecCfg aSecCfgType cha8Rand *v2.ChaCha8 } -func NewJunkCreator(d *Device) (junkCreator, error) { +func NewJunkCreator(aSecCfg aSecCfgType) (junkCreator, error) { buf := make([]byte, 32) _, err := crand.Read(buf) if err != nil { return junkCreator{}, err } - return junkCreator{device: d, cha8Rand: v2.NewChaCha8([32]byte(buf))}, nil + return junkCreator{aSecCfg: aSecCfg, cha8Rand: v2.NewChaCha8([32]byte(buf))}, nil } // Should be called with aSecMux RLocked -func (jc *junkCreator) createJunkPackets(junks *[][]byte) error { - if jc.device.awg.aSecCfg.junkPacketCount == 0 { +func (jc *junkCreator) CreateJunkPackets(junks [][]byte) error { + if jc.aSecCfg.JunkPacketCount == 0 { return nil } - *junks = make([][]byte, len(*junks)+jc.device.awg.aSecCfg.junkPacketCount) - for i := range jc.device.awg.aSecCfg.junkPacketCount { + for i := range jc.aSecCfg.JunkPacketCount { packetSize := jc.randomPacketSize() junk, err := jc.randomJunkWithSize(packetSize) if err != nil { return fmt.Errorf("create junk packet: %v", err) } - (*junks)[i] = junk + junks[i] = junk } return nil } @@ -43,13 +42,13 @@ func (jc *junkCreator) createJunkPackets(junks *[][]byte) error { func (jc *junkCreator) randomPacketSize() int { return int( jc.cha8Rand.Uint64()%uint64( - jc.device.awg.aSecCfg.junkPacketMaxSize-jc.device.awg.aSecCfg.junkPacketMinSize, + jc.aSecCfg.JunkPacketMaxSize-jc.aSecCfg.JunkPacketMinSize, ), - ) + jc.device.awg.aSecCfg.junkPacketMinSize + ) + jc.aSecCfg.JunkPacketMinSize } // Should be called with aSecMux RLocked -func (jc *junkCreator) appendJunk(writer *bytes.Buffer, size int) error { +func (jc *junkCreator) AppendJunk(writer *bytes.Buffer, size int) error { headerJunk, err := jc.randomJunkWithSize(size) if err != nil { return fmt.Errorf("create header junk: %v", err) diff --git a/device/junk_creator_test.go b/device/awg/junk_creator_test.go similarity index 60% rename from device/junk_creator_test.go rename to device/awg/junk_creator_test.go index 33a7a23..7ac6cad 100644 --- a/device/junk_creator_test.go +++ b/device/awg/junk_creator_test.go @@ -1,36 +1,27 @@ -package device +package awg import ( "bytes" "fmt" "testing" - - "github.com/amnezia-vpn/amneziawg-go/conn/bindtest" - "github.com/amnezia-vpn/amneziawg-go/tun/tuntest" ) func setUpJunkCreator(t *testing.T) (junkCreator, error) { - cfg, _ := genASecurityConfigs(t) - tun := tuntest.NewChannelTUN() - binds := bindtest.NewChannelBinds() - level := LogLevelVerbose - dev := NewDevice( - tun.TUN(), - binds[0], - NewLogger(level, ""), - ) - - if err := dev.IpcSet(cfg[0]); err != nil { - t.Errorf("failed to configure device %v", err) - dev.Close() - return junkCreator{}, err - } - - jc, err := NewJunkCreator(dev) + jc, err := NewJunkCreator(aSecCfgType{ + IsSet: true, + JunkPacketCount: 5, + JunkPacketMinSize: 500, + JunkPacketMaxSize: 1000, + InitPacketJunkSize: 30, + ResponsePacketJunkSize: 40, + InitPacketMagicHeader: 123456, + ResponsePacketMagicHeader: 67543, + UnderloadPacketMagicHeader: 32345, + TransportPacketMagicHeader: 123123, + }) if err != nil { t.Errorf("failed to create junk creator %v", err) - dev.Close() return junkCreator{}, err } @@ -42,8 +33,9 @@ func Test_junkCreator_createJunkPackets(t *testing.T) { if err != nil { return } - t.Run("", func(t *testing.T) { - got, err := jc.createJunkPackets() + t.Run("valid", func(t *testing.T) { + got := make([][]byte, jc.aSecCfg.JunkPacketCount) + err := jc.CreateJunkPackets(got) if err != nil { t.Errorf( "junkCreator.createJunkPackets() = %v; failed", @@ -68,7 +60,7 @@ func Test_junkCreator_createJunkPackets(t *testing.T) { } func Test_junkCreator_randomJunkWithSize(t *testing.T) { - t.Run("", func(t *testing.T) { + t.Run("valid", func(t *testing.T) { jc, err := setUpJunkCreator(t) if err != nil { return @@ -78,7 +70,6 @@ func Test_junkCreator_randomJunkWithSize(t *testing.T) { fmt.Printf("%v\n%v\n", r1, r2) if bytes.Equal(r1, r2) { t.Errorf("same junks %v", err) - jc.device.Close() return } }) @@ -90,14 +81,14 @@ func Test_junkCreator_randomPacketSize(t *testing.T) { return } for range [30]struct{}{} { - t.Run("", func(t *testing.T) { - if got := jc.randomPacketSize(); jc.device.awg.aSecCfg.junkPacketMinSize > got || - got > jc.device.awg.aSecCfg.junkPacketMaxSize { + t.Run("valid", func(t *testing.T) { + if got := jc.randomPacketSize(); jc.aSecCfg.JunkPacketMinSize > got || + got > jc.aSecCfg.JunkPacketMaxSize { t.Errorf( "junkCreator.randomPacketSize() = %v, not between range [%v,%v]", got, - jc.device.awg.aSecCfg.junkPacketMinSize, - jc.device.awg.aSecCfg.junkPacketMaxSize, + jc.aSecCfg.JunkPacketMinSize, + jc.aSecCfg.JunkPacketMaxSize, ) } }) @@ -109,13 +100,13 @@ func Test_junkCreator_appendJunk(t *testing.T) { if err != nil { return } - t.Run("", func(t *testing.T) { + t.Run("valid", func(t *testing.T) { s := "apple" buffer := bytes.NewBuffer([]byte(s)) - err := jc.appendJunk(buffer, 30) + err := jc.AppendJunk(buffer, 30) if err != nil && buffer.Len() != len(s)+30 { - t.Errorf("appendWithJunk() size don't match") + t.Error("appendWithJunk() size don't match") } read := make([]byte, 50) buffer.Read(read) diff --git a/device/internal/junk-tag/special_handshake_handler.go b/device/awg/special_handshake_handler.go similarity index 90% rename from device/internal/junk-tag/special_handshake_handler.go rename to device/awg/special_handshake_handler.go index caad83f..45b048b 100644 --- a/device/internal/junk-tag/special_handshake_handler.go +++ b/device/awg/special_handshake_handler.go @@ -1,4 +1,4 @@ -package junktag +package awg import ( "errors" @@ -6,13 +6,14 @@ import ( ) type SpecialHandshakeHandler struct { - SpecialJunk TaggedJunkGeneratorHandler - ControlledJunk TaggedJunkGeneratorHandler + SpecialJunk TagJunkGeneratorHandler + ControlledJunk TagJunkGeneratorHandler nextItime time.Time ITimeout time.Duration // seconds // TODO: maybe atomic? PacketCounter uint64 + IsSet bool } func (handler *SpecialHandshakeHandler) Validate() error { diff --git a/device/internal/junk-tag/generator.go b/device/awg/tag_generator.go similarity index 99% rename from device/internal/junk-tag/generator.go rename to device/awg/tag_generator.go index 8662655..1974a6e 100644 --- a/device/internal/junk-tag/generator.go +++ b/device/awg/tag_generator.go @@ -1,4 +1,4 @@ -package junktag +package awg import ( crand "crypto/rand" diff --git a/device/internal/junk-tag/generator_test.go b/device/awg/tag_generator_test.go similarity index 99% rename from device/internal/junk-tag/generator_test.go rename to device/awg/tag_generator_test.go index a80e048..a5fc334 100644 --- a/device/internal/junk-tag/generator_test.go +++ b/device/awg/tag_generator_test.go @@ -1,4 +1,4 @@ -package junktag +package awg import ( "fmt" diff --git a/device/awg/tag_junk_generator.go b/device/awg/tag_junk_generator.go new file mode 100644 index 0000000..e856177 --- /dev/null +++ b/device/awg/tag_junk_generator.go @@ -0,0 +1,46 @@ +package awg + +import ( + "fmt" + "strconv" +) + +type TagJunkGenerator struct { + name string + packetSize int + generators []Generator +} + +func newTagJunkGenerator(name string, size int) TagJunkGenerator { + return TagJunkGenerator{name: name, generators: make([]Generator, size)} +} + +func (tg *TagJunkGenerator) append(generator Generator) { + tg.generators = append(tg.generators, generator) + tg.packetSize += generator.Size() +} + +func (tg *TagJunkGenerator) generate() []byte { + packet := make([]byte, 0, tg.packetSize) + for _, generator := range tg.generators { + packet = append(packet, generator.Generate()...) + } + + return packet +} + +func (tg *TagJunkGenerator) Name() string { + return tg.name +} + +func (tg *TagJunkGenerator) nameIndex() (int, error) { + if len(tg.name) != 2 { + return 0, fmt.Errorf("name must be 2 character long: %s", tg.name) + } + + index, err := strconv.Atoi(tg.name[1:2]) + if err != nil { + return 0, fmt.Errorf("name should be 2 char long: %w", err) + } + return index, nil +} diff --git a/device/internal/junk-tag/tagged_junk_generator_handler.go b/device/awg/tag_junk_generator_handler.go similarity index 63% rename from device/internal/junk-tag/tagged_junk_generator_handler.go rename to device/awg/tag_junk_generator_handler.go index d5feb61..a2b0746 100644 --- a/device/internal/junk-tag/tagged_junk_generator_handler.go +++ b/device/awg/tag_junk_generator_handler.go @@ -1,19 +1,21 @@ -package junktag +package awg import "fmt" -type TaggedJunkGeneratorHandler struct { - generators []TaggedJunkGenerator +type TagJunkGeneratorHandler struct { + generators []TagJunkGenerator length int + // Jc + DefaultJunkCount int } -func (handler *TaggedJunkGeneratorHandler) AppendGenerator(generators TaggedJunkGenerator) { +func (handler *TagJunkGeneratorHandler) AppendGenerator(generators TagJunkGenerator) { handler.generators = append(handler.generators, generators) handler.length++ } // validate that packets were defined consecutively -func (handler *TaggedJunkGeneratorHandler) Validate() error { +func (handler *TagJunkGeneratorHandler) Validate() error { seen := make([]bool, len(handler.generators)) for _, generator := range handler.generators { if index, err := generator.nameIndex(); err != nil { @@ -32,8 +34,8 @@ func (handler *TaggedJunkGeneratorHandler) Validate() error { return nil } -func (handler *TaggedJunkGeneratorHandler) Generate() [][]byte { - var rv = make([][]byte, handler.length) +func (handler *TagJunkGeneratorHandler) Generate() [][]byte { + var rv = make([][]byte, handler.length+handler.DefaultJunkCount) for i, generator := range handler.generators { rv[i] = make([]byte, generator.packetSize) copy(rv[i], generator.generate()) diff --git a/device/internal/junk-tag/parser.go b/device/awg/tag_parser.go similarity index 76% rename from device/internal/junk-tag/parser.go rename to device/awg/tag_parser.go index 8230c46..c3254c8 100644 --- a/device/internal/junk-tag/parser.go +++ b/device/awg/tag_parser.go @@ -1,4 +1,4 @@ -package junktag +package awg import ( "fmt" @@ -54,10 +54,10 @@ func parseTag(input string) (Tag, error) { } // TODO: pointernes -func Parse(name, input string) (TaggedJunkGenerator, error) { +func Parse(name, input string) (TagJunkGenerator, error) { inputSlice := strings.Split(input, "<") if len(inputSlice) <= 1 { - return TaggedJunkGenerator{}, fmt.Errorf("empty input: %s", input) + return TagJunkGenerator{}, fmt.Errorf("empty input: %s", input) } uniqueTagCheck := make(map[EnumTag]bool, len(uniqueTags)) @@ -65,28 +65,28 @@ func Parse(name, input string) (TaggedJunkGenerator, error) { // skip byproduct of split inputSlice = inputSlice[1:] - rv := newTagedJunkGenerator(name, len(inputSlice)) + rv := newTagJunkGenerator(name, len(inputSlice)) for _, inputParam := range inputSlice { if len(inputParam) <= 1 { - return TaggedJunkGenerator{}, fmt.Errorf("empty tag in input: %s", inputSlice) + return TagJunkGenerator{}, fmt.Errorf("empty tag in input: %s", inputSlice) } else if strings.Count(inputParam, ">") != 1 { - return TaggedJunkGenerator{}, fmt.Errorf("ill formated input: %s", input) + return TagJunkGenerator{}, fmt.Errorf("ill formated input: %s", input) } tag, _ := parseTag(inputParam) creator, ok := generatorCreator[tag.Name] if !ok { - return TaggedJunkGenerator{}, fmt.Errorf("invalid tag: %s", tag.Name) + return TagJunkGenerator{}, fmt.Errorf("invalid tag: %s", tag.Name) } if present, ok := uniqueTagCheck[tag.Name]; ok { if present { - return TaggedJunkGenerator{}, fmt.Errorf("tag %s needs to be unique", tag.Name) + return TagJunkGenerator{}, fmt.Errorf("tag %s needs to be unique", tag.Name) } uniqueTagCheck[tag.Name] = true } generator, err := creator(tag.Param) if err != nil { - return TaggedJunkGenerator{}, fmt.Errorf("gen: %w", err) + return TagJunkGenerator{}, fmt.Errorf("gen: %w", err) } rv.append(generator) diff --git a/device/internal/junk-tag/parser_test.go b/device/awg/tag_parser_test.go similarity index 66% rename from device/internal/junk-tag/parser_test.go rename to device/awg/tag_parser_test.go index dfd3399..8f828ec 100644 --- a/device/internal/junk-tag/parser_test.go +++ b/device/awg/tag_parser_test.go @@ -1,4 +1,4 @@ -package junktag +package awg import ( "fmt" @@ -9,6 +9,7 @@ import ( func TestParse(t *testing.T) { type args struct { + name string input string } tests := []struct { @@ -16,39 +17,44 @@ func TestParse(t *testing.T) { args args wantErr error }{ + { + name: "invalid name", + args: args{name: "apple", input: ""}, + wantErr: fmt.Errorf("ill formated input"), + }, { name: "empty", - args: args{input: ""}, + args: args{name: "i1", input: ""}, wantErr: fmt.Errorf("ill formated input"), }, { name: "extra >", - args: args{input: ">"}, + args: args{name: "i1", input: ">"}, wantErr: fmt.Errorf("ill formated input"), }, { name: "extra <", - args: args{input: "<"}, + args: args{name: "i1", input: "<"}, wantErr: fmt.Errorf("empty tag in input"), }, { name: "empty <>", - args: args{input: "<>"}, + args: args{name: "i1", input: "<>"}, wantErr: fmt.Errorf("empty tag in input"), }, { name: "invalid tag", - args: args{input: ""}, + args: args{name: "i1", input: ""}, wantErr: fmt.Errorf("invalid tag"), }, { name: "counter uniqueness violation", - args: args{input: ""}, + args: args{name: "i1", input: ""}, wantErr: fmt.Errorf("parse tag needs to be unique"), }, { name: "timestamp uniqueness violation", - args: args{input: ""}, + args: args{name: "i1", input: ""}, wantErr: fmt.Errorf("parse tag needs to be unique"), }, { @@ -58,7 +64,7 @@ func TestParse(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - _, err := Parse(tt.args.input) + _, err := Parse(tt.args.name, tt.args.input) // TODO: ErrorAs doesn't work as you think if tt.wantErr != nil { diff --git a/device/device.go b/device/device.go index 9ba69dd..7d17ef1 100644 --- a/device/device.go +++ b/device/device.go @@ -6,18 +6,18 @@ package device import ( + "errors" "runtime" "sync" "sync/atomic" "time" "github.com/amnezia-vpn/amneziawg-go/conn" - junktag "github.com/amnezia-vpn/amneziawg-go/device/internal/junk-tag" + "github.com/amnezia-vpn/amneziawg-go/device/awg" "github.com/amnezia-vpn/amneziawg-go/ipc" "github.com/amnezia-vpn/amneziawg-go/ratelimiter" "github.com/amnezia-vpn/amneziawg-go/rwcancel" "github.com/amnezia-vpn/amneziawg-go/tun" - "github.com/tevino/abool/v2" ) type Version uint8 @@ -129,31 +129,7 @@ type Device struct { log *Logger version Version - awg awg -} - -type awg struct { - isASecOn abool.AtomicBool - // TODO: revision the need of the mutex - aSecMux sync.RWMutex - aSecCfg aSecCfgType - junkCreator junkCreator - - // TODO: determine if it's on - handshakeHandler junktag.SpecialHandshakeHandler -} - -type aSecCfgType struct { - isSet bool - junkPacketCount int - junkPacketMinSize int - junkPacketMaxSize int - initPacketJunkSize int - responsePacketJunkSize int - initPacketMagicHeader uint32 - responsePacketMagicHeader uint32 - underloadPacketMagicHeader uint32 - transportPacketMagicHeader uint32 + awg awg.Protocol } // deviceState represents the state of a Device. @@ -603,7 +579,7 @@ func (device *Device) BindClose() error { return err } func (device *Device) isAdvancedSecurityOn() bool { - return device.awg.isASecOn.IsSet() + return device.awg.IsASecOn.IsSet() } func (device *Device) resetProtocol() { @@ -614,165 +590,129 @@ func (device *Device) resetProtocol() { MessageTransportType = DefaultMessageTransportType } -func (device *Device) handlePostConfig(tempAwg *awg) (err error) { - - if !tempAwg.aSecCfg.isSet { - return err +func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error { + if !tempAwg.ASecCfg.IsSet && !tempAwg.HandshakeHandler.IsSet { + return nil } + var errs []error + isASecOn := false - device.awg.aSecMux.Lock() - if tempAwg.aSecCfg.junkPacketCount < 0 { - err = ipcErrorf( + device.awg.ASecMux.Lock() + if tempAwg.ASecCfg.JunkPacketCount < 0 { + errs = append(errs, ipcErrorf( ipc.IpcErrorInvalid, "JunkPacketCount should be non negative", + ), ) } - device.awg.aSecCfg.junkPacketCount = tempAwg.aSecCfg.junkPacketCount - if tempAwg.aSecCfg.junkPacketCount != 0 { + device.awg.ASecCfg.JunkPacketCount = tempAwg.ASecCfg.JunkPacketCount + if tempAwg.ASecCfg.JunkPacketCount != 0 { isASecOn = true } - device.awg.aSecCfg.junkPacketMinSize = tempAwg.aSecCfg.junkPacketMinSize - if tempAwg.aSecCfg.junkPacketMinSize != 0 { + device.awg.ASecCfg.JunkPacketMinSize = tempAwg.ASecCfg.JunkPacketMinSize + if tempAwg.ASecCfg.JunkPacketMinSize != 0 { isASecOn = true } - if device.awg.aSecCfg.junkPacketCount > 0 && - tempAwg.aSecCfg.junkPacketMaxSize == tempAwg.aSecCfg.junkPacketMinSize { + if device.awg.ASecCfg.JunkPacketCount > 0 && + tempAwg.ASecCfg.JunkPacketMaxSize == tempAwg.ASecCfg.JunkPacketMinSize { - tempAwg.aSecCfg.junkPacketMaxSize++ // to make rand gen work + tempAwg.ASecCfg.JunkPacketMaxSize++ // to make rand gen work } - if tempAwg.aSecCfg.junkPacketMaxSize >= MaxSegmentSize { - device.awg.aSecCfg.junkPacketMinSize = 0 - device.awg.aSecCfg.junkPacketMaxSize = 1 - if err != nil { - err = ipcErrorf( - ipc.IpcErrorInvalid, - "JunkPacketMaxSize: %d; should be smaller than maxSegmentSize: %d; %w", - tempAwg.aSecCfg.junkPacketMaxSize, - MaxSegmentSize, - err, - ) - } else { - err = ipcErrorf( - ipc.IpcErrorInvalid, - "JunkPacketMaxSize: %d; should be smaller than maxSegmentSize: %d", - tempAwg.aSecCfg.junkPacketMaxSize, - MaxSegmentSize, - ) - } - } else if tempAwg.aSecCfg.junkPacketMaxSize < tempAwg.aSecCfg.junkPacketMinSize { - if err != nil { - err = ipcErrorf( - ipc.IpcErrorInvalid, - "maxSize: %d; should be greater than minSize: %d; %w", - tempAwg.aSecCfg.junkPacketMaxSize, - tempAwg.aSecCfg.junkPacketMinSize, - err, - ) - } else { - err = ipcErrorf( - ipc.IpcErrorInvalid, - "maxSize: %d; should be greater than minSize: %d", - tempAwg.aSecCfg.junkPacketMaxSize, - tempAwg.aSecCfg.junkPacketMinSize, - ) - } + if tempAwg.ASecCfg.JunkPacketMaxSize >= MaxSegmentSize { + device.awg.ASecCfg.JunkPacketMinSize = 0 + device.awg.ASecCfg.JunkPacketMaxSize = 1 + errs = append(errs, ipcErrorf( + ipc.IpcErrorInvalid, + "JunkPacketMaxSize: %d; should be smaller than maxSegmentSize: %d", + tempAwg.ASecCfg.JunkPacketMaxSize, + MaxSegmentSize, + )) + } else if tempAwg.ASecCfg.JunkPacketMaxSize < tempAwg.ASecCfg.JunkPacketMinSize { + errs = append(errs, ipcErrorf( + ipc.IpcErrorInvalid, + "maxSize: %d; should be greater than minSize: %d", + tempAwg.ASecCfg.JunkPacketMaxSize, + tempAwg.ASecCfg.JunkPacketMinSize, + )) } else { - device.awg.aSecCfg.junkPacketMaxSize = tempAwg.aSecCfg.junkPacketMaxSize + device.awg.ASecCfg.JunkPacketMaxSize = tempAwg.ASecCfg.JunkPacketMaxSize } - if tempAwg.aSecCfg.junkPacketMaxSize != 0 { + if tempAwg.ASecCfg.JunkPacketMaxSize != 0 { isASecOn = true } - if MessageInitiationSize+tempAwg.aSecCfg.initPacketJunkSize >= MaxSegmentSize { - if err != nil { - err = ipcErrorf( - ipc.IpcErrorInvalid, - `init header size(148) + junkSize:%d; should be smaller than maxSegmentSize: %d; %w`, - tempAwg.aSecCfg.initPacketJunkSize, - MaxSegmentSize, - err, - ) - } else { - err = ipcErrorf( - ipc.IpcErrorInvalid, - `init header size(148) + junkSize:%d; should be smaller than maxSegmentSize: %d`, - tempAwg.aSecCfg.initPacketJunkSize, - MaxSegmentSize, - ) - } + if MessageInitiationSize+tempAwg.ASecCfg.InitPacketJunkSize >= MaxSegmentSize { + errs = append(errs, ipcErrorf( + ipc.IpcErrorInvalid, + `init header size(148) + junkSize:%d; should be smaller than maxSegmentSize: %d`, + tempAwg.ASecCfg.InitPacketJunkSize, + MaxSegmentSize, + ), + ) } else { - device.awg.aSecCfg.initPacketJunkSize = tempAwg.aSecCfg.initPacketJunkSize + device.awg.ASecCfg.InitPacketJunkSize = tempAwg.ASecCfg.InitPacketJunkSize } - if tempAwg.aSecCfg.initPacketJunkSize != 0 { + if tempAwg.ASecCfg.InitPacketJunkSize != 0 { isASecOn = true } - if MessageResponseSize+tempAwg.aSecCfg.responsePacketJunkSize >= MaxSegmentSize { - if err != nil { - err = ipcErrorf( - ipc.IpcErrorInvalid, - `response header size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d; %w`, - tempAwg.aSecCfg.responsePacketJunkSize, - MaxSegmentSize, - err, - ) - } else { - err = ipcErrorf( - ipc.IpcErrorInvalid, - `response header size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d`, - tempAwg.aSecCfg.responsePacketJunkSize, - MaxSegmentSize, - ) - } + if MessageResponseSize+tempAwg.ASecCfg.ResponsePacketJunkSize >= MaxSegmentSize { + errs = append(errs, ipcErrorf( + ipc.IpcErrorInvalid, + `response header size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d`, + tempAwg.ASecCfg.ResponsePacketJunkSize, + MaxSegmentSize, + ), + ) } else { - device.awg.aSecCfg.responsePacketJunkSize = tempAwg.aSecCfg.responsePacketJunkSize + device.awg.ASecCfg.ResponsePacketJunkSize = tempAwg.ASecCfg.ResponsePacketJunkSize } - if tempAwg.aSecCfg.responsePacketJunkSize != 0 { + if tempAwg.ASecCfg.ResponsePacketJunkSize != 0 { isASecOn = true } - if tempAwg.aSecCfg.initPacketMagicHeader > 4 { + if tempAwg.ASecCfg.InitPacketMagicHeader > 4 { isASecOn = true device.log.Verbosef("UAPI: Updating init_packet_magic_header") - device.awg.aSecCfg.initPacketMagicHeader = tempAwg.aSecCfg.initPacketMagicHeader - MessageInitiationType = device.awg.aSecCfg.initPacketMagicHeader + device.awg.ASecCfg.InitPacketMagicHeader = tempAwg.ASecCfg.InitPacketMagicHeader + MessageInitiationType = device.awg.ASecCfg.InitPacketMagicHeader } else { device.log.Verbosef("UAPI: Using default init type") MessageInitiationType = DefaultMessageInitiationType } - if tempAwg.aSecCfg.responsePacketMagicHeader > 4 { + if tempAwg.ASecCfg.ResponsePacketMagicHeader > 4 { isASecOn = true device.log.Verbosef("UAPI: Updating response_packet_magic_header") - device.awg.aSecCfg.responsePacketMagicHeader = tempAwg.aSecCfg.responsePacketMagicHeader - MessageResponseType = device.awg.aSecCfg.responsePacketMagicHeader + device.awg.ASecCfg.ResponsePacketMagicHeader = tempAwg.ASecCfg.ResponsePacketMagicHeader + MessageResponseType = device.awg.ASecCfg.ResponsePacketMagicHeader } else { device.log.Verbosef("UAPI: Using default response type") MessageResponseType = DefaultMessageResponseType } - if tempAwg.aSecCfg.underloadPacketMagicHeader > 4 { + if tempAwg.ASecCfg.UnderloadPacketMagicHeader > 4 { isASecOn = true device.log.Verbosef("UAPI: Updating underload_packet_magic_header") - device.awg.aSecCfg.underloadPacketMagicHeader = tempAwg.aSecCfg.underloadPacketMagicHeader - MessageCookieReplyType = device.awg.aSecCfg.underloadPacketMagicHeader + device.awg.ASecCfg.UnderloadPacketMagicHeader = tempAwg.ASecCfg.UnderloadPacketMagicHeader + MessageCookieReplyType = device.awg.ASecCfg.UnderloadPacketMagicHeader } else { device.log.Verbosef("UAPI: Using default underload type") MessageCookieReplyType = DefaultMessageCookieReplyType } - if tempAwg.aSecCfg.transportPacketMagicHeader > 4 { + if tempAwg.ASecCfg.TransportPacketMagicHeader > 4 { isASecOn = true device.log.Verbosef("UAPI: Updating transport_packet_magic_header") - device.awg.aSecCfg.transportPacketMagicHeader = tempAwg.aSecCfg.transportPacketMagicHeader - MessageTransportType = device.awg.aSecCfg.transportPacketMagicHeader + device.awg.ASecCfg.TransportPacketMagicHeader = tempAwg.ASecCfg.TransportPacketMagicHeader + MessageTransportType = device.awg.ASecCfg.TransportPacketMagicHeader } else { device.log.Verbosef("UAPI: Using default transport type") MessageTransportType = DefaultMessageTransportType @@ -787,48 +727,28 @@ func (device *Device) handlePostConfig(tempAwg *awg) (err error) { // size will be different if same values if len(isSameMap) != 4 { - if err != nil { - err = ipcErrorf( - ipc.IpcErrorInvalid, - `magic headers should differ; got: init:%d; recv:%d; unde:%d; tran:%d; %w`, - MessageInitiationType, - MessageResponseType, - MessageCookieReplyType, - MessageTransportType, - err, - ) - } else { - err = ipcErrorf( - ipc.IpcErrorInvalid, - `magic headers should differ; got: init:%d; recv:%d; unde:%d; tran:%d`, - MessageInitiationType, - MessageResponseType, - MessageCookieReplyType, - MessageTransportType, - ) - } + errs = append(errs, ipcErrorf( + ipc.IpcErrorInvalid, + `magic headers should differ; got: init:%d; recv:%d; unde:%d; tran:%d`, + MessageInitiationType, + MessageResponseType, + MessageCookieReplyType, + MessageTransportType, + ), + ) } - newInitSize := MessageInitiationSize + device.awg.aSecCfg.initPacketJunkSize - newResponseSize := MessageResponseSize + device.awg.aSecCfg.responsePacketJunkSize + newInitSize := MessageInitiationSize + device.awg.ASecCfg.InitPacketJunkSize + newResponseSize := MessageResponseSize + device.awg.ASecCfg.ResponsePacketJunkSize if newInitSize == newResponseSize { - if err != nil { - err = ipcErrorf( - ipc.IpcErrorInvalid, - `new init size:%d; and new response size:%d; should differ; %w`, - newInitSize, - newResponseSize, - err, - ) - } else { - err = ipcErrorf( - ipc.IpcErrorInvalid, - `new init size:%d; and new response size:%d; should differ`, - newInitSize, - newResponseSize, - ) - } + errs = append(errs, ipcErrorf( + ipc.IpcErrorInvalid, + `new init size:%d; and new response size:%d; should differ`, + newInitSize, + newResponseSize, + ), + ) } else { packetSizeToMsgType = map[int]uint32{ newInitSize: MessageInitiationType, @@ -838,23 +758,35 @@ func (device *Device) handlePostConfig(tempAwg *awg) (err error) { } msgTypeToJunkSize = map[uint32]int{ - MessageInitiationType: device.awg.aSecCfg.initPacketJunkSize, - MessageResponseType: device.awg.aSecCfg.responsePacketJunkSize, + MessageInitiationType: device.awg.ASecCfg.InitPacketJunkSize, + MessageResponseType: device.awg.ASecCfg.ResponsePacketJunkSize, MessageCookieReplyType: 0, MessageTransportType: 0, } } - if err := tempAwg.handshakeHandler.Validate(); err == nil { - return ipcErrorf(ipc.IpcErrorInvalid, "handle post config foo validate: %w", err) + device.awg.IsASecOn.SetTo(isASecOn) + var err error + device.awg.JunkCreator, err = awg.NewJunkCreator(device.awg.ASecCfg) + if err != nil { + errs = append(errs, err) } - device.awg.isASecOn.SetTo(isASecOn) - device.awg.junkCreator, err = NewJunkCreator(device) - device.awg.handshakeHandler = tempAwg.handshakeHandler - // TODO: - device.version = VersionAwgSpecialHandshake - device.awg.aSecMux.Unlock() + if tempAwg.HandshakeHandler.IsSet { + if err := tempAwg.HandshakeHandler.Validate(); tempAwg.HandshakeHandler.IsSet && err != nil { + errs = append(errs, ipcErrorf( + ipc.IpcErrorInvalid, "handshake handler validate: %w", err)) + } else { + device.awg.HandshakeHandler = tempAwg.HandshakeHandler + device.awg.HandshakeHandler.ControlledJunk.DefaultJunkCount = tempAwg.ASecCfg.JunkPacketCount + device.awg.HandshakeHandler.SpecialJunk.DefaultJunkCount = tempAwg.ASecCfg.JunkPacketCount + device.version = VersionAwgSpecialHandshake + } + } else { + device.version = VersionAwg + } - return err + device.awg.ASecMux.Unlock() + + return errors.Join(errs...) } diff --git a/device/internal/junk-tag/tagged_junk_generator.go b/device/internal/junk-tag/tagged_junk_generator.go deleted file mode 100644 index a89edac..0000000 --- a/device/internal/junk-tag/tagged_junk_generator.go +++ /dev/null @@ -1,42 +0,0 @@ -package junktag - -import ( - "fmt" - "strconv" -) - -type TaggedJunkGenerator struct { - name string - packetSize int - generators []Generator -} - -func newTagedJunkGenerator(name string, size int) TaggedJunkGenerator { - return TaggedJunkGenerator{name: name, generators: make([]Generator, size)} -} - -func (tg *TaggedJunkGenerator) append(generator Generator) { - tg.generators = append(tg.generators, generator) - tg.packetSize += generator.Size() -} - -func (tg *TaggedJunkGenerator) generate() []byte { - packet := make([]byte, 0, tg.packetSize) - for _, generator := range tg.generators { - packet = append(packet, generator.Generate()...) - } - - return packet -} - -func (t *TaggedJunkGenerator) nameIndex() (int, error) { - if len(t.name) != 2 { - return 0, fmt.Errorf("name must be 2 character long: %s", t.name) - } - - index, err := strconv.Atoi(t.name[1:2]) - if err != nil { - return 0, fmt.Errorf("name should be 2 char long: %w", err) - } - return index, nil -} diff --git a/device/noise-protocol.go b/device/noise-protocol.go index d774904..89c634c 100644 --- a/device/noise-protocol.go +++ b/device/noise-protocol.go @@ -204,12 +204,12 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e handshake.mixHash(handshake.remoteStatic[:]) - device.awg.aSecMux.RLock() + device.awg.ASecMux.RLock() msg := MessageInitiation{ Type: MessageInitiationType, Ephemeral: handshake.localEphemeral.publicKey(), } - device.awg.aSecMux.RUnlock() + device.awg.ASecMux.RUnlock() handshake.mixKey(msg.Ephemeral[:]) handshake.mixHash(msg.Ephemeral[:]) @@ -263,12 +263,12 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer { chainKey [blake2s.Size]byte ) - device.awg.aSecMux.RLock() + device.awg.ASecMux.RLock() if msg.Type != MessageInitiationType { - device.awg.aSecMux.RUnlock() + device.awg.ASecMux.RUnlock() return nil } - device.awg.aSecMux.RUnlock() + device.awg.ASecMux.RUnlock() device.staticIdentity.RLock() defer device.staticIdentity.RUnlock() @@ -383,9 +383,9 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error } var msg MessageResponse - device.awg.aSecMux.RLock() + device.awg.ASecMux.RLock() msg.Type = MessageResponseType - device.awg.aSecMux.RUnlock() + device.awg.ASecMux.RUnlock() msg.Sender = handshake.localIndex msg.Receiver = handshake.remoteIndex @@ -435,12 +435,12 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error } func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { - device.awg.aSecMux.RLock() + device.awg.ASecMux.RLock() if msg.Type != MessageResponseType { - device.awg.aSecMux.RUnlock() + device.awg.ASecMux.RUnlock() return nil } - device.awg.aSecMux.RUnlock() + device.awg.ASecMux.RUnlock() // lookup handshake by receiver diff --git a/device/peer.go b/device/peer.go index bdc52fc..9409ca6 100644 --- a/device/peer.go +++ b/device/peer.go @@ -137,7 +137,7 @@ func (peer *Peer) SendBuffers(buffers [][]byte) error { if err == nil { var totalLen uint64 for _, b := range buffers { - peer.device.awg.foo.PacketCounter++ + peer.device.awg.HandshakeHandler.PacketCounter++ totalLen += uint64(len(b)) } peer.txBytes.Add(totalLen) diff --git a/device/receive.go b/device/receive.go index f235a33..1875c76 100644 --- a/device/receive.go +++ b/device/receive.go @@ -129,7 +129,7 @@ func (device *Device) RoutineReceiveIncoming( } deathSpiral = 0 - device.awg.aSecMux.RLock() + device.awg.ASecMux.RLock() // handle each packet in the batch for i, size := range sizes[:count] { if size < MinMessageSize { @@ -246,7 +246,7 @@ func (device *Device) RoutineReceiveIncoming( default: } } - device.awg.aSecMux.RUnlock() + device.awg.ASecMux.RUnlock() for peer, elemsContainer := range elemsByPeer { if peer.isRunning.Load() { peer.queue.inbound.c <- elemsContainer @@ -305,7 +305,7 @@ func (device *Device) RoutineHandshake(id int) { for elem := range device.queue.handshake.c { - device.awg.aSecMux.RLock() + device.awg.ASecMux.RLock() // handle cookie fields and ratelimiting @@ -457,7 +457,7 @@ func (device *Device) RoutineHandshake(id int) { peer.SendKeepalive() } skip: - device.awg.aSecMux.RUnlock() + device.awg.ASecMux.RUnlock() device.PutMessageBuffer(elem.buffer) } } diff --git a/device/send.go b/device/send.go index f49b0e1..2e446cd 100644 --- a/device/send.go +++ b/device/send.go @@ -128,19 +128,21 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { var junkedHeader []byte if peer.device.version >= VersionAwg { - junks := [][]byte{} + var junks [][]byte if peer.device.version == VersionAwgSpecialHandshake { - peer.device.awg.aSecMux.RLock() + peer.device.awg.ASecMux.RLock() // set junks depending on packet type - junks = peer.device.awg.handshakeHandler.GenerateSpecialJunk() + junks = peer.device.awg.HandshakeHandler.GenerateSpecialJunk() if junks == nil { - junks = peer.device.awg.handshakeHandler.GenerateSpecialJunk() + junks = peer.device.awg.HandshakeHandler.GenerateSpecialJunk() } - peer.device.awg.aSecMux.RUnlock() + peer.device.awg.ASecMux.RUnlock() + } else { + junks = make([][]byte, peer.device.awg.ASecCfg.JunkPacketCount) } - peer.device.awg.aSecMux.RLock() - err := peer.device.awg.junkCreator.createJunkPackets(&junks) - peer.device.awg.aSecMux.RUnlock() + peer.device.awg.ASecMux.RLock() + err := peer.device.awg.JunkCreator.CreateJunkPackets(junks) + peer.device.awg.ASecMux.RUnlock() if err != nil { peer.device.log.Errorf("%v - %v", peer, err) @@ -156,19 +158,19 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { } } - peer.device.awg.aSecMux.RLock() - if peer.device.awg.aSecCfg.initPacketJunkSize != 0 { - buf := make([]byte, 0, peer.device.awg.aSecCfg.initPacketJunkSize) + peer.device.awg.ASecMux.RLock() + if peer.device.awg.ASecCfg.InitPacketJunkSize != 0 { + buf := make([]byte, 0, peer.device.awg.ASecCfg.InitPacketJunkSize) writer := bytes.NewBuffer(buf[:0]) - err = peer.device.awg.junkCreator.appendJunk(writer, peer.device.awg.aSecCfg.initPacketJunkSize) + err = peer.device.awg.JunkCreator.AppendJunk(writer, peer.device.awg.ASecCfg.InitPacketJunkSize) if err != nil { peer.device.log.Errorf("%v - %v", peer, err) - peer.device.awg.aSecMux.RUnlock() + peer.device.awg.ASecMux.RUnlock() return err } junkedHeader = writer.Bytes() } - peer.device.awg.aSecMux.RUnlock() + peer.device.awg.ASecMux.RUnlock() } var buf [MessageInitiationSize]byte @@ -206,19 +208,19 @@ func (peer *Peer) SendHandshakeResponse() error { } var junkedHeader []byte if peer.device.isAdvancedSecurityOn() { - peer.device.awg.aSecMux.RLock() - if peer.device.awg.aSecCfg.responsePacketJunkSize != 0 { - buf := make([]byte, 0, peer.device.awg.aSecCfg.responsePacketJunkSize) + peer.device.awg.ASecMux.RLock() + if peer.device.awg.ASecCfg.ResponsePacketJunkSize != 0 { + buf := make([]byte, 0, peer.device.awg.ASecCfg.ResponsePacketJunkSize) writer := bytes.NewBuffer(buf[:0]) - err = peer.device.awg.junkCreator.appendJunk(writer, peer.device.awg.aSecCfg.responsePacketJunkSize) + err = peer.device.awg.JunkCreator.AppendJunk(writer, peer.device.awg.ASecCfg.ResponsePacketJunkSize) if err != nil { - peer.device.awg.aSecMux.RUnlock() + peer.device.awg.ASecMux.RUnlock() peer.device.log.Errorf("%v - %v", peer, err) return err } junkedHeader = writer.Bytes() } - peer.device.awg.aSecMux.RUnlock() + peer.device.awg.ASecMux.RUnlock() } var buf [MessageResponseSize]byte writer := bytes.NewBuffer(buf[:0]) diff --git a/device/uapi.go b/device/uapi.go index 7ae4b1c..216f7cd 100644 --- a/device/uapi.go +++ b/device/uapi.go @@ -18,7 +18,7 @@ import ( "sync" "time" - junktag "github.com/amnezia-vpn/amneziawg-go/device/internal/junk-tag" + "github.com/amnezia-vpn/amneziawg-go/device/awg" "github.com/amnezia-vpn/amneziawg-go/ipc" ) @@ -99,33 +99,37 @@ func (device *Device) IpcGetOperation(w io.Writer) error { } if device.isAdvancedSecurityOn() { - if device.awg.aSecCfg.junkPacketCount != 0 { - sendf("jc=%d", device.awg.aSecCfg.junkPacketCount) + if device.awg.ASecCfg.JunkPacketCount != 0 { + sendf("jc=%d", device.awg.ASecCfg.JunkPacketCount) } - if device.awg.aSecCfg.junkPacketMinSize != 0 { - sendf("jmin=%d", device.awg.aSecCfg.junkPacketMinSize) + if device.awg.ASecCfg.JunkPacketMinSize != 0 { + sendf("jmin=%d", device.awg.ASecCfg.JunkPacketMinSize) } - if device.awg.aSecCfg.junkPacketMaxSize != 0 { - sendf("jmax=%d", device.awg.aSecCfg.junkPacketMaxSize) + if device.awg.ASecCfg.JunkPacketMaxSize != 0 { + sendf("jmax=%d", device.awg.ASecCfg.JunkPacketMaxSize) } - if device.awg.aSecCfg.initPacketJunkSize != 0 { - sendf("s1=%d", device.awg.aSecCfg.initPacketJunkSize) + if device.awg.ASecCfg.InitPacketJunkSize != 0 { + sendf("s1=%d", device.awg.ASecCfg.InitPacketJunkSize) } - if device.awg.aSecCfg.responsePacketJunkSize != 0 { - sendf("s2=%d", device.awg.aSecCfg.responsePacketJunkSize) + if device.awg.ASecCfg.ResponsePacketJunkSize != 0 { + sendf("s2=%d", device.awg.ASecCfg.ResponsePacketJunkSize) } - if device.awg.aSecCfg.initPacketMagicHeader != 0 { - sendf("h1=%d", device.awg.aSecCfg.initPacketMagicHeader) + if device.awg.ASecCfg.InitPacketMagicHeader != 0 { + sendf("h1=%d", device.awg.ASecCfg.InitPacketMagicHeader) } - if device.awg.aSecCfg.responsePacketMagicHeader != 0 { - sendf("h2=%d", device.awg.aSecCfg.responsePacketMagicHeader) + if device.awg.ASecCfg.ResponsePacketMagicHeader != 0 { + sendf("h2=%d", device.awg.ASecCfg.ResponsePacketMagicHeader) } - if device.awg.aSecCfg.underloadPacketMagicHeader != 0 { - sendf("h3=%d", device.awg.aSecCfg.underloadPacketMagicHeader) + if device.awg.ASecCfg.UnderloadPacketMagicHeader != 0 { + sendf("h3=%d", device.awg.ASecCfg.UnderloadPacketMagicHeader) } - if device.awg.aSecCfg.transportPacketMagicHeader != 0 { - sendf("h4=%d", device.awg.aSecCfg.transportPacketMagicHeader) + if device.awg.ASecCfg.TransportPacketMagicHeader != 0 { + sendf("h4=%d", device.awg.ASecCfg.TransportPacketMagicHeader) } + + // for _, generator := range device.awg.HandshakeHandler.ControlledJunk.AppendGenerator { + // sendf("h4=%d", device.awg.ASecCfg.TransportPacketMagicHeader) + // } } for _, peer := range device.peers.keyMap { @@ -181,7 +185,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) { peer := new(ipcSetPeer) deviceConfig := true - tempAwg := awg{} + tempAwg := awg.Protocol{} scanner := bufio.NewScanner(r) for scanner.Scan() { line := scanner.Text() @@ -238,7 +242,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) { return nil } -func (device *Device) handleDeviceLine(key, value string, tempAwg *awg) error { +func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol) error { switch key { case "private_key": var sk NoisePrivateKey @@ -290,8 +294,8 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg) error { return ipcErrorf(ipc.IpcErrorInvalid, "parse junk_packet_count %w", err) } device.log.Verbosef("UAPI: Updating junk_packet_count") - tempAwg.aSecCfg.junkPacketCount = junkPacketCount - tempAwg.aSecCfg.isSet = true + tempAwg.ASecCfg.JunkPacketCount = junkPacketCount + tempAwg.ASecCfg.IsSet = true case "jmin": junkPacketMinSize, err := strconv.Atoi(value) @@ -299,8 +303,8 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg) error { return ipcErrorf(ipc.IpcErrorInvalid, "parse junk_packet_min_size %w", err) } device.log.Verbosef("UAPI: Updating junk_packet_min_size") - tempAwg.aSecCfg.junkPacketMinSize = junkPacketMinSize - tempAwg.aSecCfg.isSet = true + tempAwg.ASecCfg.JunkPacketMinSize = junkPacketMinSize + tempAwg.ASecCfg.IsSet = true case "jmax": junkPacketMaxSize, err := strconv.Atoi(value) @@ -308,8 +312,8 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg) error { return ipcErrorf(ipc.IpcErrorInvalid, "parse junk_packet_max_size %w", err) } device.log.Verbosef("UAPI: Updating junk_packet_max_size") - tempAwg.aSecCfg.junkPacketMaxSize = junkPacketMaxSize - tempAwg.aSecCfg.isSet = true + tempAwg.ASecCfg.JunkPacketMaxSize = junkPacketMaxSize + tempAwg.ASecCfg.IsSet = true case "s1": initPacketJunkSize, err := strconv.Atoi(value) @@ -317,8 +321,8 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg) error { return ipcErrorf(ipc.IpcErrorInvalid, "parse init_packet_junk_size %w", err) } device.log.Verbosef("UAPI: Updating init_packet_junk_size") - tempAwg.aSecCfg.initPacketJunkSize = initPacketJunkSize - tempAwg.aSecCfg.isSet = true + tempAwg.ASecCfg.InitPacketJunkSize = initPacketJunkSize + tempAwg.ASecCfg.IsSet = true case "s2": responsePacketJunkSize, err := strconv.Atoi(value) @@ -326,65 +330,65 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg) error { return ipcErrorf(ipc.IpcErrorInvalid, "parse response_packet_junk_size %w", err) } device.log.Verbosef("UAPI: Updating response_packet_junk_size") - tempAwg.aSecCfg.responsePacketJunkSize = responsePacketJunkSize - tempAwg.aSecCfg.isSet = true + tempAwg.ASecCfg.ResponsePacketJunkSize = responsePacketJunkSize + tempAwg.ASecCfg.IsSet = true case "h1": initPacketMagicHeader, err := strconv.ParseUint(value, 10, 32) if err != nil { return ipcErrorf(ipc.IpcErrorInvalid, "parse init_packet_magic_header %w", err) } - tempAwg.aSecCfg.initPacketMagicHeader = uint32(initPacketMagicHeader) - tempAwg.aSecCfg.isSet = true + tempAwg.ASecCfg.InitPacketMagicHeader = uint32(initPacketMagicHeader) + tempAwg.ASecCfg.IsSet = true case "h2": responsePacketMagicHeader, err := strconv.ParseUint(value, 10, 32) if err != nil { return ipcErrorf(ipc.IpcErrorInvalid, "parse response_packet_magic_header %w", err) } - tempAwg.aSecCfg.responsePacketMagicHeader = uint32(responsePacketMagicHeader) - tempAwg.aSecCfg.isSet = true + tempAwg.ASecCfg.ResponsePacketMagicHeader = uint32(responsePacketMagicHeader) + tempAwg.ASecCfg.IsSet = true case "h3": underloadPacketMagicHeader, err := strconv.ParseUint(value, 10, 32) if err != nil { return ipcErrorf(ipc.IpcErrorInvalid, "parse underload_packet_magic_header %w", err) } - tempAwg.aSecCfg.underloadPacketMagicHeader = uint32(underloadPacketMagicHeader) - tempAwg.aSecCfg.isSet = true + tempAwg.ASecCfg.UnderloadPacketMagicHeader = uint32(underloadPacketMagicHeader) + tempAwg.ASecCfg.IsSet = true case "h4": transportPacketMagicHeader, err := strconv.ParseUint(value, 10, 32) if err != nil { return ipcErrorf(ipc.IpcErrorInvalid, "parse transport_packet_magic_header %w", err) } - tempAwg.aSecCfg.transportPacketMagicHeader = uint32(transportPacketMagicHeader) - tempAwg.aSecCfg.isSet = true + tempAwg.ASecCfg.TransportPacketMagicHeader = uint32(transportPacketMagicHeader) + tempAwg.ASecCfg.IsSet = true case "i1", "i2", "i3", "i4", "i5": if len(value) == 0 { return ipcErrorf(ipc.IpcErrorInvalid, "%s should be non null", key) } - generators, err := junktag.Parse(key, value) + generators, err := awg.Parse(key, value) if err != nil { return ipcErrorf(ipc.IpcErrorInvalid, "invalid %s: %w", key, err) } device.log.Verbosef("UAPI: Updating %s", key) - tempAwg.handshakeHandler.SpecialJunk.AppendGenerator(generators) - tempAwg.aSecCfg.isSet = true + tempAwg.HandshakeHandler.SpecialJunk.AppendGenerator(generators) + tempAwg.HandshakeHandler.IsSet = true case "j1", "j2", "j3": if len(value) == 0 { return ipcErrorf(ipc.IpcErrorInvalid, "%s should be non null", key) } - generators, err := junktag.Parse(key, value) + generators, err := awg.Parse(key, value) if err != nil { return ipcErrorf(ipc.IpcErrorInvalid, "invalid %s: %w", key, err) } device.log.Verbosef("UAPI: Updating %s", key) - tempAwg.handshakeHandler.ControlledJunk.AppendGenerator(generators) - tempAwg.aSecCfg.isSet = true + tempAwg.HandshakeHandler.ControlledJunk.AppendGenerator(generators) + tempAwg.HandshakeHandler.IsSet = true case "itime": itime, err := strconv.ParseInt(value, 10, 64) if err != nil { @@ -392,8 +396,8 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg) error { } device.log.Verbosef("UAPI: Updating itime %s", itime) - tempAwg.handshakeHandler.ITimeout = time.Duration(itime) - tempAwg.aSecCfg.isSet = true + tempAwg.HandshakeHandler.ITimeout = time.Duration(itime) + tempAwg.HandshakeHandler.IsSet = true default: return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI device key: %v", key) } diff --git a/go.mod b/go.mod index c150a4b..5896772 100644 --- a/go.mod +++ b/go.mod @@ -4,12 +4,12 @@ go 1.24 require ( github.com/stretchr/testify v1.10.0 - github.com/tevino/abool/v2 v2.1.0 + github.com/tevino/abool v1.2.0 golang.org/x/crypto v0.36.0 golang.org/x/net v0.37.0 golang.org/x/sys v0.31.0 golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 - gvisor.dev/gvisor v0.0.0-20250130013005-04f9204697c6 + gvisor.dev/gvisor v0.0.0-20250606233247-e3c4c4cad86f ) require ( diff --git a/go.sum b/go.sum index ec3fa7f..4b1e64a 100644 --- a/go.sum +++ b/go.sum @@ -8,8 +8,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= -github.com/tevino/abool/v2 v2.1.0 h1:7w+Vf9f/5gmKT4m4qkayb33/92M+Um45F2BkHOR+L/c= -github.com/tevino/abool/v2 v2.1.0/go.mod h1:+Lmlqk6bHDWHqN1cbxqhwEAwMPXgc8I1SDEamtseuXY= +github.com/tevino/abool v1.2.0 h1:heAkClL8H6w+mK5md9dzsuohKeXHUpY7Vw0ZCKW+huA= +github.com/tevino/abool v1.2.0/go.mod h1:qc66Pna1RiIsPa7O4Egxxs9OqkuxDX55zznh9K07Tzg= golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34= golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc= golang.org/x/mod v0.21.0 h1:vvrHzRwRfVKSiLrG+d4FMl/Qi4ukBCE6kZlTUkDYRT0= @@ -26,5 +26,5 @@ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -gvisor.dev/gvisor v0.0.0-20250130013005-04f9204697c6 h1:6B7MdW3OEbJqOMr7cEYU9bkzvCjUBX/JlXk12xcANuQ= -gvisor.dev/gvisor v0.0.0-20250130013005-04f9204697c6/go.mod h1:5DMfjtclAbTIjbXqO1qCe2K5GKKxWz2JHvCChuTcJEM= +gvisor.dev/gvisor v0.0.0-20250606233247-e3c4c4cad86f h1:zmc4cHEcCudRt2O8VsCW7nYLfAsbVY2i910/DAop1TM= +gvisor.dev/gvisor v0.0.0-20250606233247-e3c4c4cad86f/go.mod h1:3r5CMtNQMKIvBlrmM9xWUNamjKBYPOWyXOjmg5Kts3g=