diff --git a/Dockerfile b/Dockerfile index 12159be..6d60440 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,4 @@ -FROM golang:1.24 as awg +FROM golang:1.24.4 as awg COPY . /awg WORKDIR /awg RUN go mod download && \ @@ -7,10 +7,24 @@ RUN go mod download && \ FROM alpine:3.19 ARG AWGTOOLS_RELEASE="1.0.20241018" + +RUN apk add linux-headers build-base +COPY awg-tools /awg-tools +RUN pwd && ls -la / && ls -la /awg-tools +WORKDIR /awg-tools/src +# RUN ls -la && pwd && ls awg-tools +RUN make +RUN mkdir -p build && \ + cp wg ./build/awg && \ + cp wg-quick/linux.bash ./build/awg-quick + +RUN cp build/awg /usr/bin/awg +RUN cp build/awg-quick /usr/bin/awg-quick + RUN apk --no-cache add iproute2 iptables bash && \ cd /usr/bin/ && \ - wget https://github.com/amnezia-vpn/amneziawg-tools/releases/download/v${AWGTOOLS_RELEASE}/alpine-3.19-amneziawg-tools.zip && \ - unzip -j alpine-3.19-amneziawg-tools.zip && \ + # wget https://github.com/amnezia-vpn/amneziawg-tools/releases/download/v${AWGTOOLS_RELEASE}/alpine-3.19-amneziawg-tools.zip && \ + # unzip -j alpine-3.19-amneziawg-tools.zip && \ chmod +x /usr/bin/awg /usr/bin/awg-quick && \ ln -s /usr/bin/awg /usr/bin/wg && \ ln -s /usr/bin/awg-quick /usr/bin/wg-quick diff --git a/device/awg/awg.go b/device/awg/awg.go new file mode 100644 index 0000000..fd5a96d --- /dev/null +++ b/device/awg/awg.go @@ -0,0 +1,144 @@ +package awg + +import ( + "bytes" + "fmt" + "slices" + "strconv" + "strings" + "sync" + + "github.com/tevino/abool" +) + +type aSecCfgType struct { + IsSet bool + JunkPacketCount int + JunkPacketMinSize int + JunkPacketMaxSize int + InitHeaderJunkSize int + ResponseHeaderJunkSize int + CookieReplyHeaderJunkSize int + TransportHeaderJunkSize int + InitPacketMagicHeader uint32 + ResponsePacketMagicHeader uint32 + UnderloadPacketMagicHeader uint32 + TransportPacketMagicHeader uint32 + // InitPacketMagicHeader Limit + // ResponsePacketMagicHeader Limit + // UnderloadPacketMagicHeader Limit + // TransportPacketMagicHeader Limit +} + +type Limit struct { + Min uint32 + Max uint32 + HeaderType uint32 +} + +func NewLimit(min, max, headerType uint32) (Limit, error) { + if min > max { + return Limit{}, fmt.Errorf("min (%d) cannot be greater than max (%d)", min, max) + } + + return Limit{ + Min: min, + Max: max, + HeaderType: headerType, + }, nil +} + +func ParseMagicHeader(key, value string, defaultHeaderType uint32) (Limit, error) { + // tempAwg.ASecCfg.InitPacketMagicHeader, err = awg.NewLimit(uint32(initPacketMagicHeaderMin), uint32(initPacketMagicHeaderMax), DNewLimit(min, max, headerType)efaultMessageInitiationType) + // var min, max, headerType uint32 + // _, err := fmt.Sscanf(value, "%d-%d:%d", &min, &max, &headerType) + // if err != nil { + // return Limit{}, fmt.Errorf("invalid magic header format: %s", value) + // } + + limits := strings.Split(value, "-") + if len(limits) != 2 { + return Limit{}, fmt.Errorf("invalid format for key: %s; %s", key, value) + } + + min, err := strconv.ParseUint(limits[0], 10, 32) + if err != nil { + return Limit{}, fmt.Errorf("parse min key: %s; value: ; %w", key, limits[0], err) + } + + max, err := strconv.ParseUint(limits[1], 10, 32) + if err != nil { + return Limit{}, fmt.Errorf("parse max key: %s; value: ; %w", key, limits[0], err) + } + + limit, err := NewLimit(uint32(min), uint32(max), defaultHeaderType) + if err != nil { + return Limit{}, fmt.Errorf("new lmit key: %s; value: ; %w", key, limits[0], err) + } + + return limit, nil +} + +type Limits []Limit + +func NewLimits(limits []Limit) Limits { + slices.SortFunc(limits, func(a, b Limit) int { + if a.Min < b.Min { + return -1 + } else if a.Min > b.Min { + return 1 + } + return 0 + }) + + return Limits(limits) +} + +type Protocol struct { + IsASecOn abool.AtomicBool + // TODO: revision the need of the mutex + ASecMux sync.RWMutex + ASecCfg aSecCfgType + JunkCreator junkCreator + + HandshakeHandler SpecialHandshakeHandler +} + +func (protocol *Protocol) CreateInitHeaderJunk() ([]byte, error) { + return protocol.createHeaderJunk(protocol.ASecCfg.InitHeaderJunkSize) +} + +func (protocol *Protocol) CreateResponseHeaderJunk() ([]byte, error) { + return protocol.createHeaderJunk(protocol.ASecCfg.ResponseHeaderJunkSize) +} + +func (protocol *Protocol) CreateCookieReplyHeaderJunk() ([]byte, error) { + return protocol.createHeaderJunk(protocol.ASecCfg.CookieReplyHeaderJunkSize) +} + +func (protocol *Protocol) CreateTransportHeaderJunk(packetSize int) ([]byte, error) { + return protocol.createHeaderJunk(protocol.ASecCfg.TransportHeaderJunkSize, packetSize) +} + +func (protocol *Protocol) createHeaderJunk(junkSize int, optExtraSize ...int) ([]byte, error) { + extraSize := 0 + if len(optExtraSize) == 1 { + extraSize = optExtraSize[0] + } + + var junk []byte + protocol.ASecMux.RLock() + if junkSize != 0 { + buf := make([]byte, 0, junkSize+extraSize) + writer := bytes.NewBuffer(buf[:0]) + err := protocol.JunkCreator.AppendJunk(writer, junkSize) + if err != nil { + protocol.ASecMux.RUnlock() + return nil, err + } + junk = writer.Bytes() + } + protocol.ASecMux.RUnlock() + + return junk, nil +} diff --git a/device/awg/internal/mock.go b/device/awg/internal/mock.go new file mode 100644 index 0000000..a2e1c95 --- /dev/null +++ b/device/awg/internal/mock.go @@ -0,0 +1,37 @@ +package internal + +type mockGenerator struct { + size int +} + +func NewMockGenerator(size int) mockGenerator { + return mockGenerator{size: size} +} + +func (m mockGenerator) Generate() []byte { + return make([]byte, m.size) +} + +func (m mockGenerator) Size() int { + return m.size +} + +func (m mockGenerator) Name() string { + return "mock" +} + +type mockByteGenerator struct { + data []byte +} + +func NewMockByteGenerator(data []byte) mockByteGenerator { + return mockByteGenerator{data: data} +} + +func (bg mockByteGenerator) Generate() []byte { + return bg.data +} + +func (bg mockByteGenerator) Size() int { + return len(bg.data) +} diff --git a/device/junk_creator.go b/device/awg/junk_creator.go similarity index 52% rename from device/junk_creator.go rename to device/awg/junk_creator.go index 3a2d3b4..91fd253 100644 --- a/device/junk_creator.go +++ b/device/awg/junk_creator.go @@ -1,4 +1,4 @@ -package device +package awg import ( "bytes" @@ -8,61 +8,62 @@ import ( ) type junkCreator struct { - device *Device + aSecCfg aSecCfgType cha8Rand *v2.ChaCha8 } -func NewJunkCreator(d *Device) (junkCreator, error) { +// TODO: refactor param to only pass the junk related params +func NewJunkCreator(aSecCfg aSecCfgType) (junkCreator, error) { buf := make([]byte, 32) _, err := crand.Read(buf) if err != nil { return junkCreator{}, err } - return junkCreator{device: d, cha8Rand: v2.NewChaCha8([32]byte(buf))}, nil + return junkCreator{aSecCfg: aSecCfg, cha8Rand: v2.NewChaCha8([32]byte(buf))}, nil } // Should be called with aSecMux RLocked -func (jc *junkCreator) createJunkPackets() ([][]byte, error) { - if jc.device.aSecCfg.junkPacketCount == 0 { - return nil, nil +func (jc *junkCreator) CreateJunkPackets(junks *[][]byte) error { + if jc.aSecCfg.JunkPacketCount == 0 { + return nil } - junks := make([][]byte, 0, jc.device.aSecCfg.junkPacketCount) - for i := 0; i < jc.device.aSecCfg.junkPacketCount; i++ { + for range jc.aSecCfg.JunkPacketCount { packetSize := jc.randomPacketSize() junk, err := jc.randomJunkWithSize(packetSize) if err != nil { - return nil, fmt.Errorf("Failed to create junk packet: %v", err) + return fmt.Errorf("create junk packet: %v", err) } - junks = append(junks, junk) + *junks = append(*junks, junk) } - return junks, nil + return nil } // Should be called with aSecMux RLocked func (jc *junkCreator) randomPacketSize() int { return int( jc.cha8Rand.Uint64()%uint64( - jc.device.aSecCfg.junkPacketMaxSize-jc.device.aSecCfg.junkPacketMinSize, + jc.aSecCfg.JunkPacketMaxSize-jc.aSecCfg.JunkPacketMinSize, ), - ) + jc.device.aSecCfg.junkPacketMinSize + ) + jc.aSecCfg.JunkPacketMinSize } // Should be called with aSecMux RLocked -func (jc *junkCreator) appendJunk(writer *bytes.Buffer, size int) error { +func (jc *junkCreator) AppendJunk(writer *bytes.Buffer, size int) error { headerJunk, err := jc.randomJunkWithSize(size) if err != nil { - return fmt.Errorf("failed to create header junk: %v", err) + return fmt.Errorf("create header junk: %v", err) } _, err = writer.Write(headerJunk) if err != nil { - return fmt.Errorf("failed to write header junk: %v", err) + return fmt.Errorf("write header junk: %v", err) } return nil } // Should be called with aSecMux RLocked func (jc *junkCreator) randomJunkWithSize(size int) ([]byte, error) { + // TODO: use a memory pool to allocate junk := make([]byte, size) _, err := jc.cha8Rand.Read(junk) return junk, err diff --git a/device/junk_creator_test.go b/device/awg/junk_creator_test.go similarity index 61% rename from device/junk_creator_test.go rename to device/awg/junk_creator_test.go index d3cf2b3..424f104 100644 --- a/device/junk_creator_test.go +++ b/device/awg/junk_creator_test.go @@ -1,36 +1,27 @@ -package device +package awg import ( "bytes" "fmt" "testing" - - "github.com/amnezia-vpn/amneziawg-go/conn/bindtest" - "github.com/amnezia-vpn/amneziawg-go/tun/tuntest" ) func setUpJunkCreator(t *testing.T) (junkCreator, error) { - cfg, _ := genASecurityConfigs(t) - tun := tuntest.NewChannelTUN() - binds := bindtest.NewChannelBinds() - level := LogLevelVerbose - dev := NewDevice( - tun.TUN(), - binds[0], - NewLogger(level, ""), - ) - - if err := dev.IpcSet(cfg[0]); err != nil { - t.Errorf("failed to configure device %v", err) - dev.Close() - return junkCreator{}, err - } - - jc, err := NewJunkCreator(dev) + jc, err := NewJunkCreator(aSecCfgType{ + IsSet: true, + JunkPacketCount: 5, + JunkPacketMinSize: 500, + JunkPacketMaxSize: 1000, + InitHeaderJunkSize: 30, + ResponseHeaderJunkSize: 40, + InitPacketMagicHeader: 123456, + ResponsePacketMagicHeader: 67543, + UnderloadPacketMagicHeader: 32345, + TransportPacketMagicHeader: 123123, + }) if err != nil { t.Errorf("failed to create junk creator %v", err) - dev.Close() return junkCreator{}, err } @@ -42,8 +33,9 @@ func Test_junkCreator_createJunkPackets(t *testing.T) { if err != nil { return } - t.Run("", func(t *testing.T) { - got, err := jc.createJunkPackets() + t.Run("valid", func(t *testing.T) { + got := make([][]byte, 0, jc.aSecCfg.JunkPacketCount) + err := jc.CreateJunkPackets(&got) if err != nil { t.Errorf( "junkCreator.createJunkPackets() = %v; failed", @@ -68,7 +60,7 @@ func Test_junkCreator_createJunkPackets(t *testing.T) { } func Test_junkCreator_randomJunkWithSize(t *testing.T) { - t.Run("", func(t *testing.T) { + t.Run("valid", func(t *testing.T) { jc, err := setUpJunkCreator(t) if err != nil { return @@ -78,7 +70,6 @@ func Test_junkCreator_randomJunkWithSize(t *testing.T) { fmt.Printf("%v\n%v\n", r1, r2) if bytes.Equal(r1, r2) { t.Errorf("same junks %v", err) - jc.device.Close() return } }) @@ -90,14 +81,14 @@ func Test_junkCreator_randomPacketSize(t *testing.T) { return } for range [30]struct{}{} { - t.Run("", func(t *testing.T) { - if got := jc.randomPacketSize(); jc.device.aSecCfg.junkPacketMinSize > got || - got > jc.device.aSecCfg.junkPacketMaxSize { + t.Run("valid", func(t *testing.T) { + if got := jc.randomPacketSize(); jc.aSecCfg.JunkPacketMinSize > got || + got > jc.aSecCfg.JunkPacketMaxSize { t.Errorf( "junkCreator.randomPacketSize() = %v, not between range [%v,%v]", got, - jc.device.aSecCfg.junkPacketMinSize, - jc.device.aSecCfg.junkPacketMaxSize, + jc.aSecCfg.JunkPacketMinSize, + jc.aSecCfg.JunkPacketMaxSize, ) } }) @@ -109,13 +100,13 @@ func Test_junkCreator_appendJunk(t *testing.T) { if err != nil { return } - t.Run("", func(t *testing.T) { + t.Run("valid", func(t *testing.T) { s := "apple" buffer := bytes.NewBuffer([]byte(s)) - err := jc.appendJunk(buffer, 30) + err := jc.AppendJunk(buffer, 30) if err != nil && buffer.Len() != len(s)+30 { - t.Errorf("appendWithJunk() size don't match") + t.Error("appendWithJunk() size don't match") } read := make([]byte, 50) buffer.Read(read) diff --git a/device/awg/special_handshake_handler.go b/device/awg/special_handshake_handler.go new file mode 100644 index 0000000..e582d97 --- /dev/null +++ b/device/awg/special_handshake_handler.go @@ -0,0 +1,73 @@ +package awg + +import ( + "errors" + "time" + + "github.com/tevino/abool" + "go.uber.org/atomic" +) + +// TODO: atomic?/ and better way to use this +var PacketCounter *atomic.Uint64 = atomic.NewUint64(0) + +// TODO +var WaitResponse = struct { + Channel chan struct{} + ShouldWait *abool.AtomicBool +}{ + make(chan struct{}, 1), + abool.New(), +} + +type SpecialHandshakeHandler struct { + isFirstDone bool + SpecialJunk TagJunkPacketGenerators + ControlledJunk TagJunkPacketGenerators + + nextItime time.Time + ITimeout time.Duration // seconds + + IsSet bool +} + +func (handler *SpecialHandshakeHandler) Validate() error { + var errs []error + if err := handler.SpecialJunk.Validate(); err != nil { + errs = append(errs, err) + } + if err := handler.ControlledJunk.Validate(); err != nil { + errs = append(errs, err) + } + return errors.Join(errs...) +} + +func (handler *SpecialHandshakeHandler) GenerateSpecialJunk() [][]byte { + if !handler.SpecialJunk.IsDefined() { + return nil + } + + // TODO: create tests + if !handler.isFirstDone { + handler.isFirstDone = true + } else if !handler.isTimeToSendSpecial() { + return nil + } + + rv := handler.SpecialJunk.GeneratePackets() + handler.nextItime = time.Now().Add(handler.ITimeout) + + return rv +} + +func (handler *SpecialHandshakeHandler) isTimeToSendSpecial() bool { + return time.Now().After(handler.nextItime) +} + +func (handler *SpecialHandshakeHandler) GenerateControlledJunk() [][]byte { + if !handler.ControlledJunk.IsDefined() { + return nil + } + + return handler.ControlledJunk.GeneratePackets() +} diff --git a/device/awg/tag_generator.go b/device/awg/tag_generator.go new file mode 100644 index 0000000..65d8004 --- /dev/null +++ b/device/awg/tag_generator.go @@ -0,0 +1,190 @@ +package awg + +import ( + crand "crypto/rand" + "encoding/binary" + "encoding/hex" + "fmt" + "strconv" + "strings" + "time" + + v2 "math/rand/v2" + // "go.uber.org/atomic" +) + +type Generator interface { + Generate() []byte + Size() int +} + +type newGenerator func(string) (Generator, error) + +type BytesGenerator struct { + value []byte + size int +} + +func (bg *BytesGenerator) Generate() []byte { + return bg.value +} + +func (bg *BytesGenerator) Size() int { + return bg.size +} + +func newBytesGenerator(param string) (Generator, error) { + hasPrefix := strings.HasPrefix(param, "0x") || strings.HasPrefix(param, "0X") + if !hasPrefix { + return nil, fmt.Errorf("not correct hex: %s", param) + } + + hex, err := hexToBytes(param) + if err != nil { + return nil, fmt.Errorf("hexToBytes: %w", err) + } + + return &BytesGenerator{value: hex, size: len(hex)}, nil +} + +func hexToBytes(hexStr string) ([]byte, error) { + hexStr = strings.TrimPrefix(hexStr, "0x") + hexStr = strings.TrimPrefix(hexStr, "0X") + + // Ensure even length (pad with leading zero if needed) + if len(hexStr)%2 != 0 { + hexStr = "0" + hexStr + } + + return hex.DecodeString(hexStr) +} + +type RandomPacketGenerator struct { + cha8Rand *v2.ChaCha8 + size int +} + +func (rpg *RandomPacketGenerator) Generate() []byte { + junk := make([]byte, rpg.size) + rpg.cha8Rand.Read(junk) + return junk +} + +func (rpg *RandomPacketGenerator) Size() int { + return rpg.size +} + +func newRandomPacketGenerator(param string) (Generator, error) { + size, err := strconv.Atoi(param) + if err != nil { + return nil, fmt.Errorf("random packet parse int: %w", err) + } + + if size > 1000 { + return nil, fmt.Errorf("random packet size must be less than 1000") + } + + buf := make([]byte, 32) + _, err = crand.Read(buf) + if err != nil { + return nil, fmt.Errorf("random packet crand read: %w", err) + } + + return &RandomPacketGenerator{ + cha8Rand: v2.NewChaCha8([32]byte(buf)), + size: size, + }, nil +} + +type TimestampGenerator struct { +} + +func (tg *TimestampGenerator) Generate() []byte { + buf := make([]byte, 8) + binary.BigEndian.PutUint64(buf, uint64(time.Now().Unix())) + return buf +} + +func (tg *TimestampGenerator) Size() int { + return 8 +} + +func newTimestampGenerator(param string) (Generator, error) { + if len(param) != 0 { + return nil, fmt.Errorf("timestamp param needs to be empty: %s", param) + } + + return &TimestampGenerator{}, nil +} + +type WaitTimeoutGenerator struct { + waitTimeout time.Duration +} + +func (wtg *WaitTimeoutGenerator) Generate() []byte { + time.Sleep(wtg.waitTimeout) + return []byte{} +} + +func (wtg *WaitTimeoutGenerator) Size() int { + return 0 +} + +func newWaitTimeoutGenerator(param string) (Generator, error) { + timeout, err := strconv.Atoi(param) + if err != nil { + return nil, fmt.Errorf("timeout parse int: %w", err) + } + + if timeout > 5000 { + return nil, fmt.Errorf("timeout must be less than 5000ms") + } + + return &WaitTimeoutGenerator{ + waitTimeout: time.Duration(timeout) * time.Millisecond, + }, nil +} + +type PacketCounterGenerator struct { +} + +func (c *PacketCounterGenerator) Generate() []byte { + buf := make([]byte, 8) + // TODO: better way to handle counter tag + binary.BigEndian.PutUint64(buf, PacketCounter.Load()) + return buf +} + +func (c *PacketCounterGenerator) Size() int { + return 8 +} + +func newPacketCounterGenerator(param string) (Generator, error) { + if len(param) != 0 { + return nil, fmt.Errorf("packet counter param needs to be empty: %s", param) + } + + return &PacketCounterGenerator{}, nil +} + +type WaitResponseGenerator struct { +} + +func (c *WaitResponseGenerator) Generate() []byte { + WaitResponse.ShouldWait.Set() + <-WaitResponse.Channel + WaitResponse.ShouldWait.UnSet() + return []byte{} +} + +func (c *WaitResponseGenerator) Size() int { + return 0 +} + +func newWaitResponseGenerator(param string) (Generator, error) { + if len(param) != 0 { + return nil, fmt.Errorf("wait response param needs to be empty: %s", param) + } + + return &WaitResponseGenerator{}, nil +} diff --git a/device/awg/tag_generator_test.go b/device/awg/tag_generator_test.go new file mode 100644 index 0000000..4950b33 --- /dev/null +++ b/device/awg/tag_generator_test.go @@ -0,0 +1,189 @@ +package awg + +import ( + "encoding/binary" + "fmt" + "testing" + + "github.com/stretchr/testify/require" +) + +func Test_newBytesGenerator(t *testing.T) { + type args struct { + param string + } + tests := []struct { + name string + args args + want []byte + wantErr error + }{ + { + name: "empty", + args: args{ + param: "", + }, + wantErr: fmt.Errorf("not correct hex"), + }, + { + name: "wrong start", + args: args{ + param: "123456", + }, + wantErr: fmt.Errorf("not correct hex"), + }, + { + name: "not only hex value with X", + args: args{ + param: "0X12345q", + }, + wantErr: fmt.Errorf("not correct hex"), + }, + { + name: "not only hex value with x", + args: args{ + param: "0x12345q", + }, + wantErr: fmt.Errorf("not correct hex"), + }, + { + name: "valid hex", + args: args{ + param: "0xf6ab3267fa", + }, + want: []byte{0xf6, 0xab, 0x32, 0x67, 0xfa}, + }, + { + name: "valid hex with odd length", + args: args{ + param: "0xfab3267fa", + }, + want: []byte{0xf, 0xab, 0x32, 0x67, 0xfa}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := newBytesGenerator(tt.args.param) + + if tt.wantErr != nil { + require.ErrorAs(t, err, &tt.wantErr) + require.Nil(t, got) + return + } + + require.Nil(t, err) + require.NotNil(t, got) + + gotValues := got.Generate() + require.Equal(t, tt.want, gotValues) + }) + } +} + +func Test_newRandomPacketGenerator(t *testing.T) { + type args struct { + param string + } + tests := []struct { + name string + args args + wantErr error + }{ + { + name: "empty", + args: args{ + param: "", + }, + wantErr: fmt.Errorf("parse int"), + }, + { + name: "not an int", + args: args{ + param: "x", + }, + wantErr: fmt.Errorf("parse int"), + }, + { + name: "too large", + args: args{ + param: "1001", + }, + wantErr: fmt.Errorf("random packet size must be less than 1000"), + }, + { + name: "valid", + args: args{ + param: "12", + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := newRandomPacketGenerator(tt.args.param) + if tt.wantErr != nil { + require.ErrorAs(t, err, &tt.wantErr) + require.Nil(t, got) + return + } + + require.Nil(t, err) + require.NotNil(t, got) + first := got.Generate() + + second := got.Generate() + require.NotEqual(t, first, second) + }) + } +} + +func TestPacketCounterGenerator(t *testing.T) { + tests := []struct { + name string + param string + wantErr bool + }{ + { + name: "Valid empty param", + param: "", + wantErr: false, + }, + { + name: "Invalid non-empty param", + param: "anything", + wantErr: true, + }, + } + + for _, tc := range tests { + tc := tc // capture range variable + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + gen, err := newPacketCounterGenerator(tc.param) + if tc.wantErr { + require.Error(t, err) + return + } + + require.NoError(t, err) + require.Equal(t, 8, gen.Size()) + + // Reset counter to known value for test + initialCount := uint64(42) + PacketCounter.Store(initialCount) + + output := gen.Generate() + require.Equal(t, 8, len(output)) + + // Verify counter value in output + counterValue := binary.BigEndian.Uint64(output) + require.Equal(t, initialCount, counterValue) + + // Increment counter and verify change + PacketCounter.Add(1) + output = gen.Generate() + counterValue = binary.BigEndian.Uint64(output) + require.Equal(t, initialCount+1, counterValue) + }) + } +} diff --git a/device/awg/tag_junk_packet_generator.go b/device/awg/tag_junk_packet_generator.go new file mode 100644 index 0000000..fdbebc8 --- /dev/null +++ b/device/awg/tag_junk_packet_generator.go @@ -0,0 +1,59 @@ +package awg + +import ( + "fmt" + "strconv" +) + +type TagJunkPacketGenerator struct { + name string + tagValue string + + packetSize int + generators []Generator +} + +func newTagJunkPacketGenerator(name, tagValue string, size int) TagJunkPacketGenerator { + return TagJunkPacketGenerator{ + name: name, + tagValue: tagValue, + generators: make([]Generator, 0, size), + } +} + +func (tg *TagJunkPacketGenerator) append(generator Generator) { + tg.generators = append(tg.generators, generator) + tg.packetSize += generator.Size() +} + +func (tg *TagJunkPacketGenerator) generatePacket() []byte { + packet := make([]byte, 0, tg.packetSize) + for _, generator := range tg.generators { + packet = append(packet, generator.Generate()...) + } + + return packet +} + +func (tg *TagJunkPacketGenerator) Name() string { + return tg.name +} + +func (tg *TagJunkPacketGenerator) nameIndex() (int, error) { + if len(tg.name) != 2 { + return 0, fmt.Errorf("name must be 2 character long: %s", tg.name) + } + + index, err := strconv.Atoi(tg.name[1:2]) + if err != nil { + return 0, fmt.Errorf("name 2 char should be an int %w", err) + } + return index, nil +} + +func (tg *TagJunkPacketGenerator) IpcGetFields() IpcFields { + return IpcFields{ + Key: tg.name, + Value: tg.tagValue, + } +} diff --git a/device/awg/tag_junk_packet_generator_test.go b/device/awg/tag_junk_packet_generator_test.go new file mode 100644 index 0000000..309d425 --- /dev/null +++ b/device/awg/tag_junk_packet_generator_test.go @@ -0,0 +1,210 @@ +package awg + +import ( + "testing" + + "github.com/amnezia-vpn/amneziawg-go/device/awg/internal" + "github.com/stretchr/testify/require" +) + +func TestNewTagJunkGenerator(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + genName string + size int + expected TagJunkPacketGenerator + }{ + { + name: "Create new generator with empty name", + genName: "", + size: 0, + expected: TagJunkPacketGenerator{ + name: "", + packetSize: 0, + generators: make([]Generator, 0), + }, + }, + { + name: "Create new generator with valid name", + genName: "T1", + size: 0, + expected: TagJunkPacketGenerator{ + name: "T1", + packetSize: 0, + generators: make([]Generator, 0), + }, + }, + { + name: "Create new generator with non-zero size", + genName: "T2", + size: 5, + expected: TagJunkPacketGenerator{ + name: "T2", + packetSize: 0, + generators: make([]Generator, 5), + }, + }, + } + + for _, tc := range testCases { + tc := tc // capture range variable + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + result := newTagJunkPacketGenerator(tc.genName, "", tc.size) + require.Equal(t, tc.expected.name, result.name) + require.Equal(t, tc.expected.packetSize, result.packetSize) + require.Equal(t, cap(result.generators), len(tc.expected.generators)) + }) + } +} + +func TestTagJunkGeneratorAppend(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + initialState TagJunkPacketGenerator + mockSize int + expectedLength int + expectedSize int + }{ + { + name: "Append to empty generator", + initialState: newTagJunkPacketGenerator("T1", "", 0), + mockSize: 5, + expectedLength: 1, + expectedSize: 5, + }, + { + name: "Append to non-empty generator", + initialState: TagJunkPacketGenerator{ + name: "T2", + packetSize: 10, + generators: make([]Generator, 2), + }, + mockSize: 7, + expectedLength: 3, // 2 existing + 1 new + expectedSize: 17, // 10 + 7 + }, + } + + for _, tc := range testCases { + tc := tc // capture range variable + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + tg := tc.initialState + mockGen := internal.NewMockGenerator(tc.mockSize) + + tg.append(mockGen) + + require.Equal(t, tc.expectedLength, len(tg.generators)) + require.Equal(t, tc.expectedSize, tg.packetSize) + }) + } +} + +func TestTagJunkGeneratorGenerate(t *testing.T) { + t.Parallel() + + // Create mock generators for testing + mockGen1 := internal.NewMockByteGenerator([]byte{0x01, 0x02}) + mockGen2 := internal.NewMockByteGenerator([]byte{0x03, 0x04, 0x05}) + + testCases := []struct { + name string + setupGenerator func() TagJunkPacketGenerator + expected []byte + }{ + { + name: "Generate with empty generators", + setupGenerator: func() TagJunkPacketGenerator { + return newTagJunkPacketGenerator("T1", "", 0) + }, + expected: []byte{}, + }, + { + name: "Generate with single generator", + setupGenerator: func() TagJunkPacketGenerator { + tg := newTagJunkPacketGenerator("T2", "", 0) + tg.append(mockGen1) + return tg + }, + expected: []byte{0x01, 0x02}, + }, + { + name: "Generate with multiple generators", + setupGenerator: func() TagJunkPacketGenerator { + tg := newTagJunkPacketGenerator("T3", "", 0) + tg.append(mockGen1) + tg.append(mockGen2) + return tg + }, + expected: []byte{0x01, 0x02, 0x03, 0x04, 0x05}, + }, + } + + for _, tc := range testCases { + tc := tc // capture range variable + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + tg := tc.setupGenerator() + result := tg.generatePacket() + + require.Equal(t, tc.expected, result) + }) + } +} + +func TestTagJunkGeneratorNameIndex(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + generatorName string + expectedIndex int + expectError bool + }{ + { + name: "Valid name with digit", + generatorName: "T5", + expectedIndex: 5, + expectError: false, + }, + { + name: "Invalid name - too short", + generatorName: "T", + expectError: true, + }, + { + name: "Invalid name - too long", + generatorName: "T55", + expectError: true, + }, + { + name: "Invalid name - non-digit second character", + generatorName: "TX", + expectError: true, + }, + } + + for _, tc := range testCases { + tc := tc // capture range variable + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + tg := TagJunkPacketGenerator{name: tc.generatorName} + index, err := tg.nameIndex() + + if tc.expectError { + require.Error(t, err) + } else { + require.NoError(t, err) + require.Equal(t, tc.expectedIndex, index) + } + }) + } +} diff --git a/device/awg/tag_junk_packet_generators.go b/device/awg/tag_junk_packet_generators.go new file mode 100644 index 0000000..9921eb0 --- /dev/null +++ b/device/awg/tag_junk_packet_generators.go @@ -0,0 +1,66 @@ +package awg + +import "fmt" + +type TagJunkPacketGenerators struct { + tagGenerators []TagJunkPacketGenerator + length int + DefaultJunkCount int // Jc +} + +func (generators *TagJunkPacketGenerators) AppendGenerator( + generator TagJunkPacketGenerator, +) { + generators.tagGenerators = append(generators.tagGenerators, generator) + generators.length++ +} + +func (generators *TagJunkPacketGenerators) IsDefined() bool { + return len(generators.tagGenerators) > 0 +} + +// validate that packets were defined consecutively +func (generators *TagJunkPacketGenerators) Validate() error { + seen := make([]bool, len(generators.tagGenerators)) + for _, generator := range generators.tagGenerators { + index, err := generator.nameIndex() + if index > len(generators.tagGenerators) { + return fmt.Errorf("junk packet index should be consecutive") + } + if err != nil { + return fmt.Errorf("name index: %w", err) + } else { + seen[index-1] = true + } + } + + for _, found := range seen { + if !found { + return fmt.Errorf("junk packet index should be consecutive") + } + } + + return nil +} + +func (generators *TagJunkPacketGenerators) GeneratePackets() [][]byte { + var rv = make([][]byte, 0, generators.length+generators.DefaultJunkCount) + + for i, tagGenerator := range generators.tagGenerators { + rv = append(rv, make([]byte, tagGenerator.packetSize)) + copy(rv[i], tagGenerator.generatePacket()) + PacketCounter.Inc() + } + PacketCounter.Add(uint64(generators.DefaultJunkCount)) + + return rv +} + +func (tg *TagJunkPacketGenerators) IpcGetFields() []IpcFields { + rv := make([]IpcFields, 0, len(tg.tagGenerators)) + for _, generator := range tg.tagGenerators { + rv = append(rv, generator.IpcGetFields()) + } + + return rv +} diff --git a/device/awg/tag_junk_packet_generators_test.go b/device/awg/tag_junk_packet_generators_test.go new file mode 100644 index 0000000..6b1fd47 --- /dev/null +++ b/device/awg/tag_junk_packet_generators_test.go @@ -0,0 +1,149 @@ +package awg + +import ( + "testing" + + "github.com/amnezia-vpn/amneziawg-go/device/awg/internal" + "github.com/stretchr/testify/require" +) + +func TestTagJunkGeneratorHandlerAppendGenerator(t *testing.T) { + tests := []struct { + name string + generator TagJunkPacketGenerator + }{ + { + name: "append single generator", + generator: newTagJunkPacketGenerator("t1", "", 10), + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + generators := &TagJunkPacketGenerators{} + + // Initial length should be 0 + require.Equal(t, 0, generators.length) + require.Empty(t, generators.tagGenerators) + + // After append, length should be 1 and generator should be added + generators.AppendGenerator(tt.generator) + require.Equal(t, 1, generators.length) + require.Len(t, generators.tagGenerators, 1) + require.Equal(t, tt.generator, generators.tagGenerators[0]) + }) + } +} + +func TestTagJunkGeneratorHandlerValidate(t *testing.T) { + tests := []struct { + name string + generators []TagJunkPacketGenerator + wantErr bool + errMsg string + }{ + { + name: "bad start", + generators: []TagJunkPacketGenerator{ + newTagJunkPacketGenerator("t3", "", 10), + newTagJunkPacketGenerator("t4", "", 10), + }, + wantErr: true, + errMsg: "junk packet index should be consecutive", + }, + { + name: "non-consecutive indices", + generators: []TagJunkPacketGenerator{ + newTagJunkPacketGenerator("t1", "", 10), + newTagJunkPacketGenerator("t3", "", 10), // Missing t2 + }, + wantErr: true, + errMsg: "junk packet index should be consecutive", + }, + { + name: "consecutive indices", + generators: []TagJunkPacketGenerator{ + newTagJunkPacketGenerator("t1", "", 10), + newTagJunkPacketGenerator("t2", "", 10), + newTagJunkPacketGenerator("t3", "", 10), + newTagJunkPacketGenerator("t4", "", 10), + newTagJunkPacketGenerator("t5", "", 10), + }, + }, + { + name: "nameIndex error", + generators: []TagJunkPacketGenerator{ + newTagJunkPacketGenerator("error", "", 10), + }, + wantErr: true, + errMsg: "name must be 2 character long", + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + generators := &TagJunkPacketGenerators{} + for _, gen := range tt.generators { + generators.AppendGenerator(gen) + } + + err := generators.Validate() + if tt.wantErr { + require.Error(t, err) + require.Contains(t, err.Error(), tt.errMsg) + return + } + require.NoError(t, err) + }) + } +} + +func TestTagJunkGeneratorHandlerGenerate(t *testing.T) { + mockByte1 := []byte{0x01, 0x02} + mockByte2 := []byte{0x03, 0x04, 0x05} + mockGen1 := internal.NewMockByteGenerator(mockByte1) + mockGen2 := internal.NewMockByteGenerator(mockByte2) + + tests := []struct { + name string + setupGenerator func() []TagJunkPacketGenerator + expected [][]byte + }{ + { + name: "generate with no default junk", + setupGenerator: func() []TagJunkPacketGenerator { + tg1 := newTagJunkPacketGenerator("t1", "", 0) + tg1.append(mockGen1) + tg1.append(mockGen2) + tg2 := newTagJunkPacketGenerator("t2", "", 0) + tg2.append(mockGen2) + tg2.append(mockGen1) + + return []TagJunkPacketGenerator{tg1, tg2} + }, + expected: [][]byte{ + append(mockByte1, mockByte2...), + append(mockByte2, mockByte1...), + }, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + generators := &TagJunkPacketGenerators{} + tagGenerators := tt.setupGenerator() + for _, gen := range tagGenerators { + generators.AppendGenerator(gen) + } + + result := generators.GeneratePackets() + require.Equal(t, result, tt.expected) + }) + } +} diff --git a/device/awg/tag_parser.go b/device/awg/tag_parser.go new file mode 100644 index 0000000..2b09226 --- /dev/null +++ b/device/awg/tag_parser.go @@ -0,0 +1,112 @@ +package awg + +import ( + "fmt" + "maps" + "regexp" + "strings" +) + +type IpcFields struct{ Key, Value string } + +type EnumTag string + +const ( + BytesEnumTag EnumTag = "b" + CounterEnumTag EnumTag = "c" + TimestampEnumTag EnumTag = "t" + RandomBytesEnumTag EnumTag = "r" + WaitTimeoutEnumTag EnumTag = "wt" + WaitResponseEnumTag EnumTag = "wr" +) + +var generatorCreator = map[EnumTag]newGenerator{ + BytesEnumTag: newBytesGenerator, + CounterEnumTag: newPacketCounterGenerator, + TimestampEnumTag: newTimestampGenerator, + RandomBytesEnumTag: newRandomPacketGenerator, + WaitTimeoutEnumTag: newWaitTimeoutGenerator, + // WaitResponseEnumTag: newWaitResponseGenerator, +} + +// helper map to determine enumTags are unique +var uniqueTags = map[EnumTag]bool{ + CounterEnumTag: false, + TimestampEnumTag: false, +} + +type Tag struct { + Name EnumTag + Param string +} + +func parseTag(input string) (Tag, error) { + // Regular expression to match + re := regexp.MustCompile(`([a-zA-Z]+)(?:\s+([^>]+))?>`) + + match := re.FindStringSubmatch(input) + tag := Tag{ + Name: EnumTag(match[1]), + } + if len(match) > 2 && match[2] != "" { + tag.Param = strings.TrimSpace(match[2]) + } + + return tag, nil +} + +func Parse(name, input string) (TagJunkPacketGenerator, error) { + inputSlice := strings.Split(input, "<") + if len(inputSlice) <= 1 { + return TagJunkPacketGenerator{}, fmt.Errorf("empty input: %s", input) + } + + uniqueTagCheck := make(map[EnumTag]bool, len(uniqueTags)) + maps.Copy(uniqueTagCheck, uniqueTags) + + // skip byproduct of split + inputSlice = inputSlice[1:] + rv := newTagJunkPacketGenerator(name, input, len(inputSlice)) + for _, inputParam := range inputSlice { + if len(inputParam) <= 1 { + return TagJunkPacketGenerator{}, fmt.Errorf( + "empty tag in input: %s", + inputSlice, + ) + } else if strings.Count(inputParam, ">") != 1 { + return TagJunkPacketGenerator{}, fmt.Errorf("ill formated input: %s", input) + } + + tag, _ := parseTag(inputParam) + creator, ok := generatorCreator[tag.Name] + if !ok { + return TagJunkPacketGenerator{}, fmt.Errorf("invalid tag: %s", tag.Name) + } + if present, ok := uniqueTagCheck[tag.Name]; ok { + if present { + return TagJunkPacketGenerator{}, fmt.Errorf( + "tag %s needs to be unique", + tag.Name, + ) + } + uniqueTagCheck[tag.Name] = true + } + generator, err := creator(tag.Param) + if err != nil { + return TagJunkPacketGenerator{}, fmt.Errorf("gen: %w", err) + } + + // TODO: handle counter tag + // if tag.Name == CounterEnumTag { + // packetCounter, ok := generator.(*PacketCounterGenerator) + // if !ok { + // log.Fatalf("packet counter generator expected, got %T", generator) + // } + // PacketCounter = packetCounter.counter + // } + + rv.append(generator) + } + + return rv, nil +} diff --git a/device/awg/tag_parser_test.go b/device/awg/tag_parser_test.go new file mode 100644 index 0000000..8f828ec --- /dev/null +++ b/device/awg/tag_parser_test.go @@ -0,0 +1,77 @@ +package awg + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestParse(t *testing.T) { + type args struct { + name string + input string + } + tests := []struct { + name string + args args + wantErr error + }{ + { + name: "invalid name", + args: args{name: "apple", input: ""}, + wantErr: fmt.Errorf("ill formated input"), + }, + { + name: "empty", + args: args{name: "i1", input: ""}, + wantErr: fmt.Errorf("ill formated input"), + }, + { + name: "extra >", + args: args{name: "i1", input: ">"}, + wantErr: fmt.Errorf("ill formated input"), + }, + { + name: "extra <", + args: args{name: "i1", input: "<"}, + wantErr: fmt.Errorf("empty tag in input"), + }, + { + name: "empty <>", + args: args{name: "i1", input: "<>"}, + wantErr: fmt.Errorf("empty tag in input"), + }, + { + name: "invalid tag", + args: args{name: "i1", input: ""}, + wantErr: fmt.Errorf("invalid tag"), + }, + { + name: "counter uniqueness violation", + args: args{name: "i1", input: ""}, + wantErr: fmt.Errorf("parse tag needs to be unique"), + }, + { + name: "timestamp uniqueness violation", + args: args{name: "i1", input: ""}, + wantErr: fmt.Errorf("parse tag needs to be unique"), + }, + { + name: "valid", + args: args{input: ""}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := Parse(tt.args.name, tt.args.input) + + // TODO: ErrorAs doesn't work as you think + if tt.wantErr != nil { + require.ErrorAs(t, err, &tt.wantErr) + return + } + require.Nil(t, err) + }) + } +} diff --git a/device/device.go b/device/device.go index 124b74e..1829352 100644 --- a/device/device.go +++ b/device/device.go @@ -6,19 +6,55 @@ package device import ( + "errors" "runtime" "sync" "sync/atomic" "time" "github.com/amnezia-vpn/amneziawg-go/conn" + "github.com/amnezia-vpn/amneziawg-go/device/awg" "github.com/amnezia-vpn/amneziawg-go/ipc" "github.com/amnezia-vpn/amneziawg-go/ratelimiter" "github.com/amnezia-vpn/amneziawg-go/rwcancel" "github.com/amnezia-vpn/amneziawg-go/tun" - "github.com/tevino/abool/v2" ) +type Version uint8 + +const ( + VersionDefault Version = iota + VersionAwg + VersionAwgSpecialHandshake +) + +// TODO: +type AtomicVersion struct { + value atomic.Uint32 +} + +func NewAtomicVersion(v Version) *AtomicVersion { + av := &AtomicVersion{} + av.Store(v) + return av +} + +func (av *AtomicVersion) Load() Version { + return Version(av.value.Load()) +} + +func (av *AtomicVersion) Store(v Version) { + av.value.Store(uint32(v)) +} + +func (av *AtomicVersion) CompareAndSwap(old, new Version) bool { + return av.value.CompareAndSwap(uint32(old), uint32(new)) +} + +func (av *AtomicVersion) Swap(new Version) Version { + return Version(av.value.Swap(uint32(new))) +} + type Device struct { state struct { // state holds the device's state. It is accessed atomically. @@ -92,23 +128,8 @@ type Device struct { closed chan struct{} log *Logger - isASecOn abool.AtomicBool - aSecMux sync.RWMutex - aSecCfg aSecCfgType - junkCreator junkCreator -} - -type aSecCfgType struct { - isSet bool - junkPacketCount int - junkPacketMinSize int - junkPacketMaxSize int - initPacketJunkSize int - responsePacketJunkSize int - initPacketMagicHeader uint32 - responsePacketMagicHeader uint32 - underloadPacketMagicHeader uint32 - transportPacketMagicHeader uint32 + version Version + awg awg.Protocol } // deviceState represents the state of a Device. @@ -557,251 +578,261 @@ func (device *Device) BindClose() error { device.net.Unlock() return err } -func (device *Device) isAdvancedSecurityOn() bool { - return device.isASecOn.IsSet() +func (device *Device) isAWG() bool { + return device.version >= VersionAwg } func (device *Device) resetProtocol() { // restore default message type values - MessageInitiationType = 1 - MessageResponseType = 2 - MessageCookieReplyType = 3 - MessageTransportType = 4 + MessageInitiationType = DefaultMessageInitiationType + MessageResponseType = DefaultMessageResponseType + MessageCookieReplyType = DefaultMessageCookieReplyType + MessageTransportType = DefaultMessageTransportType } -func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) { - - if !tempASecCfg.isSet { - return err +func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error { + if !tempAwg.ASecCfg.IsSet && !tempAwg.HandshakeHandler.IsSet { + return nil } + var errs []error + isASecOn := false - device.aSecMux.Lock() - if tempASecCfg.junkPacketCount < 0 { - err = ipcErrorf( + device.awg.ASecMux.Lock() + if tempAwg.ASecCfg.JunkPacketCount < 0 { + errs = append(errs, ipcErrorf( ipc.IpcErrorInvalid, "JunkPacketCount should be non negative", + ), ) } - device.aSecCfg.junkPacketCount = tempASecCfg.junkPacketCount - if tempASecCfg.junkPacketCount != 0 { + device.awg.ASecCfg.JunkPacketCount = tempAwg.ASecCfg.JunkPacketCount + if tempAwg.ASecCfg.JunkPacketCount != 0 { isASecOn = true } - device.aSecCfg.junkPacketMinSize = tempASecCfg.junkPacketMinSize - if tempASecCfg.junkPacketMinSize != 0 { + device.awg.ASecCfg.JunkPacketMinSize = tempAwg.ASecCfg.JunkPacketMinSize + if tempAwg.ASecCfg.JunkPacketMinSize != 0 { isASecOn = true } - if device.aSecCfg.junkPacketCount > 0 && - tempASecCfg.junkPacketMaxSize == tempASecCfg.junkPacketMinSize { + if device.awg.ASecCfg.JunkPacketCount > 0 && + tempAwg.ASecCfg.JunkPacketMaxSize == tempAwg.ASecCfg.JunkPacketMinSize { - tempASecCfg.junkPacketMaxSize++ // to make rand gen work + tempAwg.ASecCfg.JunkPacketMaxSize++ // to make rand gen work } - if tempASecCfg.junkPacketMaxSize >= MaxSegmentSize { - device.aSecCfg.junkPacketMinSize = 0 - device.aSecCfg.junkPacketMaxSize = 1 - if err != nil { - err = ipcErrorf( - ipc.IpcErrorInvalid, - "JunkPacketMaxSize: %d; should be smaller than maxSegmentSize: %d; %w", - tempASecCfg.junkPacketMaxSize, - MaxSegmentSize, - err, - ) - } else { - err = ipcErrorf( - ipc.IpcErrorInvalid, - "JunkPacketMaxSize: %d; should be smaller than maxSegmentSize: %d", - tempASecCfg.junkPacketMaxSize, - MaxSegmentSize, - ) - } - } else if tempASecCfg.junkPacketMaxSize < tempASecCfg.junkPacketMinSize { - if err != nil { - err = ipcErrorf( - ipc.IpcErrorInvalid, - "maxSize: %d; should be greater than minSize: %d; %w", - tempASecCfg.junkPacketMaxSize, - tempASecCfg.junkPacketMinSize, - err, - ) - } else { - err = ipcErrorf( - ipc.IpcErrorInvalid, - "maxSize: %d; should be greater than minSize: %d", - tempASecCfg.junkPacketMaxSize, - tempASecCfg.junkPacketMinSize, - ) - } + if tempAwg.ASecCfg.JunkPacketMaxSize >= MaxSegmentSize { + device.awg.ASecCfg.JunkPacketMinSize = 0 + device.awg.ASecCfg.JunkPacketMaxSize = 1 + errs = append(errs, ipcErrorf( + ipc.IpcErrorInvalid, + "JunkPacketMaxSize: %d; should be smaller than maxSegmentSize: %d", + tempAwg.ASecCfg.JunkPacketMaxSize, + MaxSegmentSize, + )) + } else if tempAwg.ASecCfg.JunkPacketMaxSize < tempAwg.ASecCfg.JunkPacketMinSize { + errs = append(errs, ipcErrorf( + ipc.IpcErrorInvalid, + "maxSize: %d; should be greater than minSize: %d", + tempAwg.ASecCfg.JunkPacketMaxSize, + tempAwg.ASecCfg.JunkPacketMinSize, + )) } else { - device.aSecCfg.junkPacketMaxSize = tempASecCfg.junkPacketMaxSize + device.awg.ASecCfg.JunkPacketMaxSize = tempAwg.ASecCfg.JunkPacketMaxSize } - if tempASecCfg.junkPacketMaxSize != 0 { + if tempAwg.ASecCfg.JunkPacketMaxSize != 0 { isASecOn = true } - if MessageInitiationSize+tempASecCfg.initPacketJunkSize >= MaxSegmentSize { - if err != nil { - err = ipcErrorf( - ipc.IpcErrorInvalid, - `init header size(148) + junkSize:%d; should be smaller than maxSegmentSize: %d; %w`, - tempASecCfg.initPacketJunkSize, - MaxSegmentSize, - err, - ) - } else { - err = ipcErrorf( - ipc.IpcErrorInvalid, - `init header size(148) + junkSize:%d; should be smaller than maxSegmentSize: %d`, - tempASecCfg.initPacketJunkSize, - MaxSegmentSize, - ) - } + newInitSize := MessageInitiationSize + tempAwg.ASecCfg.InitHeaderJunkSize + + if newInitSize >= MaxSegmentSize { + errs = append(errs, ipcErrorf( + ipc.IpcErrorInvalid, + `init header size(148) + junkSize:%d; should be smaller than maxSegmentSize: %d`, + tempAwg.ASecCfg.InitHeaderJunkSize, + MaxSegmentSize, + ), + ) } else { - device.aSecCfg.initPacketJunkSize = tempASecCfg.initPacketJunkSize + device.awg.ASecCfg.InitHeaderJunkSize = tempAwg.ASecCfg.InitHeaderJunkSize } - if tempASecCfg.initPacketJunkSize != 0 { + if tempAwg.ASecCfg.InitHeaderJunkSize != 0 { isASecOn = true } - if MessageResponseSize+tempASecCfg.responsePacketJunkSize >= MaxSegmentSize { - if err != nil { - err = ipcErrorf( - ipc.IpcErrorInvalid, - `response header size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d; %w`, - tempASecCfg.responsePacketJunkSize, - MaxSegmentSize, - err, - ) - } else { - err = ipcErrorf( - ipc.IpcErrorInvalid, - `response header size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d`, - tempASecCfg.responsePacketJunkSize, - MaxSegmentSize, - ) - } + newResponseSize := MessageResponseSize + tempAwg.ASecCfg.ResponseHeaderJunkSize + + if newResponseSize >= MaxSegmentSize { + errs = append(errs, ipcErrorf( + ipc.IpcErrorInvalid, + `response header size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d`, + tempAwg.ASecCfg.ResponseHeaderJunkSize, + MaxSegmentSize, + ), + ) } else { - device.aSecCfg.responsePacketJunkSize = tempASecCfg.responsePacketJunkSize + device.awg.ASecCfg.ResponseHeaderJunkSize = tempAwg.ASecCfg.ResponseHeaderJunkSize } - if tempASecCfg.responsePacketJunkSize != 0 { + if tempAwg.ASecCfg.ResponseHeaderJunkSize != 0 { isASecOn = true } - if tempASecCfg.initPacketMagicHeader > 4 { + newCookieSize := MessageCookieReplySize + tempAwg.ASecCfg.CookieReplyHeaderJunkSize + + if newCookieSize >= MaxSegmentSize { + errs = append(errs, ipcErrorf( + ipc.IpcErrorInvalid, + `cookie reply size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d`, + tempAwg.ASecCfg.CookieReplyHeaderJunkSize, + MaxSegmentSize, + ), + ) + } else { + device.awg.ASecCfg.CookieReplyHeaderJunkSize = tempAwg.ASecCfg.CookieReplyHeaderJunkSize + } + + if tempAwg.ASecCfg.CookieReplyHeaderJunkSize != 0 { + isASecOn = true + } + + newTransportSize := MessageTransportSize + tempAwg.ASecCfg.TransportHeaderJunkSize + + if newTransportSize >= MaxSegmentSize { + errs = append(errs, ipcErrorf( + ipc.IpcErrorInvalid, + `transport size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d`, + tempAwg.ASecCfg.TransportHeaderJunkSize, + MaxSegmentSize, + ), + ) + } else { + device.awg.ASecCfg.TransportHeaderJunkSize = tempAwg.ASecCfg.TransportHeaderJunkSize + } + + if tempAwg.ASecCfg.TransportHeaderJunkSize != 0 { + isASecOn = true + } + + if tempAwg.ASecCfg.InitPacketMagicHeader > 4 { isASecOn = true device.log.Verbosef("UAPI: Updating init_packet_magic_header") - device.aSecCfg.initPacketMagicHeader = tempASecCfg.initPacketMagicHeader - MessageInitiationType = device.aSecCfg.initPacketMagicHeader + device.awg.ASecCfg.InitPacketMagicHeader = tempAwg.ASecCfg.InitPacketMagicHeader + MessageInitiationType = device.awg.ASecCfg.InitPacketMagicHeader } else { device.log.Verbosef("UAPI: Using default init type") - MessageInitiationType = 1 + MessageInitiationType = DefaultMessageInitiationType } - if tempASecCfg.responsePacketMagicHeader > 4 { + if tempAwg.ASecCfg.ResponsePacketMagicHeader > 4 { isASecOn = true device.log.Verbosef("UAPI: Updating response_packet_magic_header") - device.aSecCfg.responsePacketMagicHeader = tempASecCfg.responsePacketMagicHeader - MessageResponseType = device.aSecCfg.responsePacketMagicHeader + device.awg.ASecCfg.ResponsePacketMagicHeader = tempAwg.ASecCfg.ResponsePacketMagicHeader + MessageResponseType = device.awg.ASecCfg.ResponsePacketMagicHeader } else { device.log.Verbosef("UAPI: Using default response type") - MessageResponseType = 2 + MessageResponseType = DefaultMessageResponseType } - if tempASecCfg.underloadPacketMagicHeader > 4 { + if tempAwg.ASecCfg.UnderloadPacketMagicHeader > 4 { isASecOn = true device.log.Verbosef("UAPI: Updating underload_packet_magic_header") - device.aSecCfg.underloadPacketMagicHeader = tempASecCfg.underloadPacketMagicHeader - MessageCookieReplyType = device.aSecCfg.underloadPacketMagicHeader + device.awg.ASecCfg.UnderloadPacketMagicHeader = tempAwg.ASecCfg.UnderloadPacketMagicHeader + MessageCookieReplyType = device.awg.ASecCfg.UnderloadPacketMagicHeader } else { device.log.Verbosef("UAPI: Using default underload type") - MessageCookieReplyType = 3 + MessageCookieReplyType = DefaultMessageCookieReplyType } - if tempASecCfg.transportPacketMagicHeader > 4 { + if tempAwg.ASecCfg.TransportPacketMagicHeader > 4 { isASecOn = true device.log.Verbosef("UAPI: Updating transport_packet_magic_header") - device.aSecCfg.transportPacketMagicHeader = tempASecCfg.transportPacketMagicHeader - MessageTransportType = device.aSecCfg.transportPacketMagicHeader + device.awg.ASecCfg.TransportPacketMagicHeader = tempAwg.ASecCfg.TransportPacketMagicHeader + MessageTransportType = device.awg.ASecCfg.TransportPacketMagicHeader } else { device.log.Verbosef("UAPI: Using default transport type") - MessageTransportType = 4 + MessageTransportType = DefaultMessageTransportType } - isSameMap := map[uint32]bool{} - isSameMap[MessageInitiationType] = true - isSameMap[MessageResponseType] = true - isSameMap[MessageCookieReplyType] = true - isSameMap[MessageTransportType] = true + isSameHeaderMap := map[uint32]struct{}{ + MessageInitiationType: {}, + MessageResponseType: {}, + MessageCookieReplyType: {}, + MessageTransportType: {}, + } // size will be different if same values - if len(isSameMap) != 4 { - if err != nil { - err = ipcErrorf( - ipc.IpcErrorInvalid, - `magic headers should differ; got: init:%d; recv:%d; unde:%d; tran:%d; %w`, - MessageInitiationType, - MessageResponseType, - MessageCookieReplyType, - MessageTransportType, - err, - ) - } else { - err = ipcErrorf( - ipc.IpcErrorInvalid, - `magic headers should differ; got: init:%d; recv:%d; unde:%d; tran:%d`, - MessageInitiationType, - MessageResponseType, - MessageCookieReplyType, - MessageTransportType, - ) + if len(isSameHeaderMap) != 4 { + errs = append(errs, ipcErrorf( + ipc.IpcErrorInvalid, + `magic headers should differ; got: init:%d; recv:%d; unde:%d; tran:%d`, + MessageInitiationType, + MessageResponseType, + MessageCookieReplyType, + MessageTransportType, + ), + ) + } + + isSameSizeMap := map[int]struct{}{ + newInitSize: {}, + newResponseSize: {}, + newCookieSize: {}, + newTransportSize: {}, + } + + if len(isSameSizeMap) != 4 { + errs = append(errs, ipcErrorf( + ipc.IpcErrorInvalid, + `new sizes should differ; init: %d; response: %d; cookie: %d; trans: %d`, + newInitSize, + newResponseSize, + newCookieSize, + newTransportSize, + ), + ) + } else { + msgTypeToJunkSize = map[uint32]int{ + MessageInitiationType: device.awg.ASecCfg.InitHeaderJunkSize, + MessageResponseType: device.awg.ASecCfg.ResponseHeaderJunkSize, + MessageCookieReplyType: device.awg.ASecCfg.CookieReplyHeaderJunkSize, + MessageTransportType: device.awg.ASecCfg.TransportHeaderJunkSize, + } + + packetSizeToMsgType = map[int]uint32{ + newInitSize: MessageInitiationType, + newResponseSize: MessageResponseType, + newCookieSize: MessageCookieReplyType, + newTransportSize: MessageTransportType, } } - newInitSize := MessageInitiationSize + device.aSecCfg.initPacketJunkSize - newResponseSize := MessageResponseSize + device.aSecCfg.responsePacketJunkSize + device.awg.IsASecOn.SetTo(isASecOn) + var err error + device.awg.JunkCreator, err = awg.NewJunkCreator(device.awg.ASecCfg) + if err != nil { + errs = append(errs, err) + } - if newInitSize == newResponseSize { - if err != nil { - err = ipcErrorf( - ipc.IpcErrorInvalid, - `new init size:%d; and new response size:%d; should differ; %w`, - newInitSize, - newResponseSize, - err, - ) + if tempAwg.HandshakeHandler.IsSet { + if err := tempAwg.HandshakeHandler.Validate(); err != nil { + errs = append(errs, ipcErrorf( + ipc.IpcErrorInvalid, "handshake handler validate: %w", err)) } else { - err = ipcErrorf( - ipc.IpcErrorInvalid, - `new init size:%d; and new response size:%d; should differ`, - newInitSize, - newResponseSize, - ) + device.awg.HandshakeHandler = tempAwg.HandshakeHandler + device.awg.HandshakeHandler.ControlledJunk.DefaultJunkCount = tempAwg.ASecCfg.JunkPacketCount + device.awg.HandshakeHandler.SpecialJunk.DefaultJunkCount = tempAwg.ASecCfg.JunkPacketCount + device.version = VersionAwgSpecialHandshake } } else { - packetSizeToMsgType = map[int]uint32{ - newInitSize: MessageInitiationType, - newResponseSize: MessageResponseType, - MessageCookieReplySize: MessageCookieReplyType, - MessageTransportSize: MessageTransportType, - } - - msgTypeToJunkSize = map[uint32]int{ - MessageInitiationType: device.aSecCfg.initPacketJunkSize, - MessageResponseType: device.aSecCfg.responsePacketJunkSize, - MessageCookieReplyType: 0, - MessageTransportType: 0, - } + device.version = VersionAwg } - device.isASecOn.SetTo(isASecOn) - device.junkCreator, err = NewJunkCreator(device) - device.aSecMux.Unlock() + device.awg.ASecMux.Unlock() - return err + return errors.Join(errs...) } diff --git a/device/device_test.go b/device/device_test.go index f66d326..5824cf9 100644 --- a/device/device_test.go +++ b/device/device_test.go @@ -7,19 +7,22 @@ package device import ( "bytes" + "context" "encoding/hex" "fmt" "io" "math/rand" "net/netip" "os" + "os/signal" "runtime" "runtime/pprof" "sync" - "sync/atomic" "testing" "time" + "go.uber.org/atomic" + "github.com/amnezia-vpn/amneziawg-go/conn" "github.com/amnezia-vpn/amneziawg-go/conn/bindtest" "github.com/amnezia-vpn/amneziawg-go/tun" @@ -50,7 +53,7 @@ func uapiCfg(cfg ...string) string { // genConfigs generates a pair of configs that connect to each other. // The configs use distinct, probably-usable ports. -func genConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) { +func genConfigs(tb testing.TB, cfg ...string) (cfgs, endpointCfgs [2]string) { var key1, key2 NoisePrivateKey _, err := rand.Read(key1[:]) if err != nil { @@ -62,7 +65,8 @@ func genConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) { } pub1, pub2 := key1.publicKey(), key2.publicKey() - cfgs[0] = uapiCfg( + args0 := append([]string(nil), cfg...) + args0 = append(args0, []string{ "private_key", hex.EncodeToString(key1[:]), "listen_port", "0", "replace_peers", "true", @@ -70,12 +74,16 @@ func genConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) { "protocol_version", "1", "replace_allowed_ips", "true", "allowed_ip", "1.0.0.2/32", - ) + }...) + cfgs[0] = uapiCfg(args0...) + endpointCfgs[0] = uapiCfg( "public_key", hex.EncodeToString(pub2[:]), "endpoint", "127.0.0.1:%d", ) - cfgs[1] = uapiCfg( + + args1 := append([]string(nil), cfg...) + args1 = append(args1, []string{ "private_key", hex.EncodeToString(key2[:]), "listen_port", "0", "replace_peers", "true", @@ -83,66 +91,9 @@ func genConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) { "protocol_version", "1", "replace_allowed_ips", "true", "allowed_ip", "1.0.0.1/32", - ) - endpointCfgs[1] = uapiCfg( - "public_key", hex.EncodeToString(pub1[:]), - "endpoint", "127.0.0.1:%d", - ) - return -} + }...) -func genASecurityConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) { - var key1, key2 NoisePrivateKey - _, err := rand.Read(key1[:]) - if err != nil { - tb.Errorf("unable to generate private key random bytes: %v", err) - } - _, err = rand.Read(key2[:]) - if err != nil { - tb.Errorf("unable to generate private key random bytes: %v", err) - } - pub1, pub2 := key1.publicKey(), key2.publicKey() - - cfgs[0] = uapiCfg( - "private_key", hex.EncodeToString(key1[:]), - "listen_port", "0", - "replace_peers", "true", - "jc", "5", - "jmin", "500", - "jmax", "1000", - "s1", "30", - "s2", "40", - "h1", "123456", - "h2", "67543", - "h4", "32345", - "h3", "123123", - "public_key", hex.EncodeToString(pub2[:]), - "protocol_version", "1", - "replace_allowed_ips", "true", - "allowed_ip", "1.0.0.2/32", - ) - endpointCfgs[0] = uapiCfg( - "public_key", hex.EncodeToString(pub2[:]), - "endpoint", "127.0.0.1:%d", - ) - cfgs[1] = uapiCfg( - "private_key", hex.EncodeToString(key2[:]), - "listen_port", "0", - "replace_peers", "true", - "jc", "5", - "jmin", "500", - "jmax", "1000", - "s1", "30", - "s2", "40", - "h1", "123456", - "h2", "67543", - "h4", "32345", - "h3", "123123", - "public_key", hex.EncodeToString(pub1[:]), - "protocol_version", "1", - "replace_allowed_ips", "true", - "allowed_ip", "1.0.0.1/32", - ) + cfgs[1] = uapiCfg(args1...) endpointCfgs[1] = uapiCfg( "public_key", hex.EncodeToString(pub1[:]), "endpoint", "127.0.0.1:%d", @@ -185,9 +136,10 @@ func (pair *testPair) Send( // pong is the new ping p0, p1 = p1, p0 } + msg := tuntest.Ping(p0.ip, p1.ip) p1.tun.Outbound <- msg - timer := time.NewTimer(5 * time.Second) + timer := time.NewTimer(6 * time.Second) defer timer.Stop() var err error select { @@ -214,14 +166,12 @@ func (pair *testPair) Send( // genTestPair creates a testPair. func genTestPair( tb testing.TB, - realSocket, withASecurity bool, + realSocket bool, + extraCfg ...string, ) (pair testPair) { var cfg, endpointCfg [2]string - if withASecurity { - cfg, endpointCfg = genASecurityConfigs(tb) - } else { - cfg, endpointCfg = genConfigs(tb) - } + cfg, endpointCfg = genConfigs(tb, extraCfg...) + var binds [2]conn.Bind if realSocket { binds[0], binds[1] = conn.NewDefaultBind(), conn.NewDefaultBind() @@ -265,7 +215,7 @@ func genTestPair( func TestTwoDevicePing(t *testing.T) { goroutineLeakCheck(t) - pair := genTestPair(t, true, false) + pair := genTestPair(t, true) t.Run("ping 1.0.0.1", func(t *testing.T) { pair.Send(t, Ping, nil) }) @@ -274,9 +224,23 @@ func TestTwoDevicePing(t *testing.T) { }) } -func TestASecurityTwoDevicePing(t *testing.T) { +// Run test with -race=false to avoid the race for setting the default msgTypes 2 times +func TestAWGDevicePing(t *testing.T) { goroutineLeakCheck(t) - pair := genTestPair(t, true, true) + + pair := genTestPair(t, true, + "jc", "5", + "jmin", "500", + "jmax", "1000", + "s1", "30", + "s2", "40", + "s3", "50", + "s4", "5", + "h1", "123456", + "h2", "67543", + "h3", "123123", + "h4", "32345", + ) t.Run("ping 1.0.0.1", func(t *testing.T) { pair.Send(t, Ping, nil) }) @@ -285,13 +249,58 @@ func TestASecurityTwoDevicePing(t *testing.T) { }) } +// Needs to be stopped with Ctrl-C +func TestAWGHandshakeDevicePing(t *testing.T) { + t.Skip("This test is intended to be run manually, not as part of the test suite.") + + signalContext, cancel := signal.NotifyContext(context.Background(), os.Interrupt) + defer cancel() + isRunning := atomic.NewBool(true) + go func() { + <-signalContext.Done() + fmt.Println("Waiting to finish") + isRunning.Store(false) + }() + + goroutineLeakCheck(t) + pair := genTestPair(t, true, + "i1", "", + "i2", "", + "j1", "", + "j2", "", + "j3", "", + "itime", "60", + // "jc", "1", + // "jmin", "500", + // "jmax", "1000", + // "s1", "30", + // "s2", "40", + // "h1", "123456", + // "h2", "67543", + // "h4", "32345", + // "h3", "123123", + ) + t.Run("ping 1.0.0.1", func(t *testing.T) { + for isRunning.Load() { + pair.Send(t, Ping, nil) + time.Sleep(2 * time.Second) + } + }) + t.Run("ping 1.0.0.2", func(t *testing.T) { + for isRunning.Load() { + pair.Send(t, Pong, nil) + time.Sleep(2 * time.Second) + } + }) +} + func TestUpDown(t *testing.T) { goroutineLeakCheck(t) const itrials = 50 const otrials = 10 for n := 0; n < otrials; n++ { - pair := genTestPair(t, false, false) + pair := genTestPair(t, false) for i := range pair { for k := range pair[i].dev.peers.keyMap { pair[i].dev.IpcSet(fmt.Sprintf("public_key=%s\npersistent_keepalive_interval=1\n", hex.EncodeToString(k[:]))) @@ -325,7 +334,7 @@ func TestUpDown(t *testing.T) { // TestConcurrencySafety does other things concurrently with tunnel use. // It is intended to be used with the race detector to catch data races. func TestConcurrencySafety(t *testing.T) { - pair := genTestPair(t, true, false) + pair := genTestPair(t, true) done := make(chan struct{}) const warmupIters = 10 @@ -406,7 +415,7 @@ func TestConcurrencySafety(t *testing.T) { } func BenchmarkLatency(b *testing.B) { - pair := genTestPair(b, true, false) + pair := genTestPair(b, true) // Establish a connection. pair.Send(b, Ping, nil) @@ -420,7 +429,7 @@ func BenchmarkLatency(b *testing.B) { } func BenchmarkThroughput(b *testing.B) { - pair := genTestPair(b, true, false) + pair := genTestPair(b, true) // Establish a connection. pair.Send(b, Ping, nil) @@ -464,7 +473,7 @@ func BenchmarkThroughput(b *testing.B) { } func BenchmarkUAPIGet(b *testing.B) { - pair := genTestPair(b, true, false) + pair := genTestPair(b, true) pair.Send(b, Ping, nil) pair.Send(b, Pong, nil) b.ReportAllocs() diff --git a/device/noise-protocol.go b/device/noise-protocol.go index 789eb16..f637b24 100644 --- a/device/noise-protocol.go +++ b/device/noise-protocol.go @@ -52,11 +52,18 @@ const ( WGLabelCookie = "cookie--" ) +const ( + DefaultMessageInitiationType uint32 = 1 + DefaultMessageResponseType uint32 = 2 + DefaultMessageCookieReplyType uint32 = 3 + DefaultMessageTransportType uint32 = 4 +) + var ( - MessageInitiationType uint32 = 1 - MessageResponseType uint32 = 2 - MessageCookieReplyType uint32 = 3 - MessageTransportType uint32 = 4 + MessageInitiationType uint32 = DefaultMessageInitiationType + MessageResponseType uint32 = DefaultMessageResponseType + MessageCookieReplyType uint32 = DefaultMessageCookieReplyType + MessageTransportType uint32 = DefaultMessageTransportType ) const ( @@ -75,9 +82,10 @@ const ( MessageTransportOffsetContent = 16 ) -var packetSizeToMsgType map[int]uint32 - -var msgTypeToJunkSize map[uint32]int +var ( + packetSizeToMsgType map[int]uint32 + msgTypeToJunkSize map[uint32]int +) /* Type is an 8-bit field, followed by 3 nul bytes, * by marshalling the messages in little-endian byteorder @@ -197,12 +205,12 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e handshake.mixHash(handshake.remoteStatic[:]) - device.aSecMux.RLock() + device.awg.ASecMux.RLock() msg := MessageInitiation{ Type: MessageInitiationType, Ephemeral: handshake.localEphemeral.publicKey(), } - device.aSecMux.RUnlock() + device.awg.ASecMux.RUnlock() handshake.mixKey(msg.Ephemeral[:]) handshake.mixHash(msg.Ephemeral[:]) @@ -256,12 +264,12 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer { chainKey [blake2s.Size]byte ) - device.aSecMux.RLock() + device.awg.ASecMux.RLock() if msg.Type != MessageInitiationType { - device.aSecMux.RUnlock() + device.awg.ASecMux.RUnlock() return nil } - device.aSecMux.RUnlock() + device.awg.ASecMux.RUnlock() device.staticIdentity.RLock() defer device.staticIdentity.RUnlock() @@ -376,9 +384,9 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error } var msg MessageResponse - device.aSecMux.RLock() + device.awg.ASecMux.RLock() msg.Type = MessageResponseType - device.aSecMux.RUnlock() + device.awg.ASecMux.RUnlock() msg.Sender = handshake.localIndex msg.Receiver = handshake.remoteIndex @@ -428,12 +436,12 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error } func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { - device.aSecMux.RLock() + device.awg.ASecMux.RLock() if msg.Type != MessageResponseType { - device.aSecMux.RUnlock() + device.awg.ASecMux.RUnlock() return nil } - device.aSecMux.RUnlock() + device.awg.ASecMux.RUnlock() // lookup handshake by receiver diff --git a/device/peer.go b/device/peer.go index 8f88b2a..e8a5168 100644 --- a/device/peer.go +++ b/device/peer.go @@ -13,6 +13,7 @@ import ( "time" "github.com/amnezia-vpn/amneziawg-go/conn" + "github.com/amnezia-vpn/amneziawg-go/device/awg" ) type Peer struct { @@ -113,6 +114,16 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { return peer, nil } +func (peer *Peer) SendAndCountBuffers(buffers [][]byte) error { + err := peer.SendBuffers(buffers) + if err == nil { + awg.PacketCounter.Add(uint64(len(buffers))) + return nil + } + + return err +} + func (peer *Peer) SendBuffers(buffers [][]byte) error { peer.device.net.RLock() defer peer.device.net.RUnlock() diff --git a/device/receive.go b/device/receive.go index 0a4910a..6daba0d 100644 --- a/device/receive.go +++ b/device/receive.go @@ -129,7 +129,7 @@ func (device *Device) RoutineReceiveIncoming( } deathSpiral = 0 - device.aSecMux.RLock() + device.awg.ASecMux.RLock() // handle each packet in the batch for i, size := range sizes[:count] { if size < MinMessageSize { @@ -137,10 +137,14 @@ func (device *Device) RoutineReceiveIncoming( } // check size of packet - packet := bufsArrs[i][:size] var msgType uint32 - if device.isAdvancedSecurityOn() { + if device.isAWG() { + // TODO: + // if awg.WaitResponse.ShouldWait.IsSet() { + // awg.WaitResponse.Channel <- struct{}{} + // } + if assumedMsgType, ok := packetSizeToMsgType[size]; ok { junkSize := msgTypeToJunkSize[assumedMsgType] // transport size can align with other header types; @@ -149,19 +153,29 @@ func (device *Device) RoutineReceiveIncoming( if msgType == assumedMsgType { packet = packet[junkSize:] } else { - device.log.Verbosef("Transport packet lined up with another msg type") + device.log.Verbosef("transport packet lined up with another msg type") msgType = binary.LittleEndian.Uint32(packet[:4]) } } else { - msgType = binary.LittleEndian.Uint32(packet[:4]) + transportJunkSize := device.awg.ASecCfg.TransportHeaderJunkSize + msgType = binary.LittleEndian.Uint32(packet[transportJunkSize : transportJunkSize+4]) if msgType != MessageTransportType { - device.log.Verbosef("ASec: Received message with unknown type") + // probably a junk packet + device.log.Verbosef("aSec: Received message with unknown type: %d", msgType) continue } + + // remove junk from bufsArrs by shifting the packet + // this buffer is also used for decryption, so it needs to be corrected + copy(bufsArrs[i][:size], packet[transportJunkSize:]) + size -= transportJunkSize + // need to reinitialize packet as well + packet = packet[:size] } } else { msgType = binary.LittleEndian.Uint32(packet[:4]) } + switch msgType { // check if transport @@ -245,7 +259,7 @@ func (device *Device) RoutineReceiveIncoming( default: } } - device.aSecMux.RUnlock() + device.awg.ASecMux.RUnlock() for peer, elemsContainer := range elemsByPeer { if peer.isRunning.Load() { peer.queue.inbound.c <- elemsContainer @@ -304,7 +318,7 @@ func (device *Device) RoutineHandshake(id int) { for elem := range device.queue.handshake.c { - device.aSecMux.RLock() + device.awg.ASecMux.RLock() // handle cookie fields and ratelimiting @@ -456,7 +470,7 @@ func (device *Device) RoutineHandshake(id int) { peer.SendKeepalive() } skip: - device.aSecMux.RUnlock() + device.awg.ASecMux.RUnlock() device.PutMessageBuffer(elem.buffer) } } diff --git a/device/send.go b/device/send.go index 7f0faa3..04ca2ad 100644 --- a/device/send.go +++ b/device/send.go @@ -124,12 +124,30 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { return err } var sendBuffer [][]byte + // so only packet processed for cookie generation var junkedHeader []byte - if peer.device.isAdvancedSecurityOn() { - peer.device.aSecMux.RLock() - junks, err := peer.device.junkCreator.createJunkPackets() - peer.device.aSecMux.RUnlock() + if peer.device.version >= VersionAwg { + var junks [][]byte + if peer.device.version == VersionAwgSpecialHandshake { + peer.device.awg.ASecMux.RLock() + // set junks depending on packet type + junks = peer.device.awg.HandshakeHandler.GenerateSpecialJunk() + if junks == nil { + junks = peer.device.awg.HandshakeHandler.GenerateControlledJunk() + if junks != nil { + peer.device.log.Verbosef("%v - Controlled junks sent", peer) + } + } else { + peer.device.log.Verbosef("%v - Special junks sent", peer) + } + peer.device.awg.ASecMux.RUnlock() + } else { + junks = make([][]byte, 0, peer.device.awg.ASecCfg.JunkPacketCount) + } + peer.device.awg.ASecMux.RLock() + err := peer.device.awg.JunkCreator.CreateJunkPackets(&junks) + peer.device.awg.ASecMux.RUnlock() if err != nil { peer.device.log.Errorf("%v - %v", peer, err) @@ -145,19 +163,11 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { } } - peer.device.aSecMux.RLock() - if peer.device.aSecCfg.initPacketJunkSize != 0 { - buf := make([]byte, 0, peer.device.aSecCfg.initPacketJunkSize) - writer := bytes.NewBuffer(buf[:0]) - err = peer.device.junkCreator.appendJunk(writer, peer.device.aSecCfg.initPacketJunkSize) - if err != nil { - peer.device.log.Errorf("%v - %v", peer, err) - peer.device.aSecMux.RUnlock() - return err - } - junkedHeader = writer.Bytes() + junkedHeader, err = peer.device.awg.CreateInitHeaderJunk() + if err != nil { + peer.device.log.Errorf("%v - %v", peer, err) + return err } - peer.device.aSecMux.RUnlock() } var buf [MessageInitiationSize]byte @@ -172,7 +182,7 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { sendBuffer = append(sendBuffer, junkedHeader) - err = peer.SendBuffers(sendBuffer) + err = peer.SendAndCountBuffers(sendBuffer) if err != nil { peer.device.log.Errorf("%v - Failed to send handshake initiation: %v", peer, err) } @@ -193,22 +203,13 @@ func (peer *Peer) SendHandshakeResponse() error { peer.device.log.Errorf("%v - Failed to create response message: %v", peer, err) return err } - var junkedHeader []byte - if peer.device.isAdvancedSecurityOn() { - peer.device.aSecMux.RLock() - if peer.device.aSecCfg.responsePacketJunkSize != 0 { - buf := make([]byte, 0, peer.device.aSecCfg.responsePacketJunkSize) - writer := bytes.NewBuffer(buf[:0]) - err = peer.device.junkCreator.appendJunk(writer, peer.device.aSecCfg.responsePacketJunkSize) - if err != nil { - peer.device.aSecMux.RUnlock() - peer.device.log.Errorf("%v - %v", peer, err) - return err - } - junkedHeader = writer.Bytes() - } - peer.device.aSecMux.RUnlock() + + junkedHeader, err := peer.device.awg.CreateResponseHeaderJunk() + if err != nil { + peer.device.log.Errorf("%v - %v", peer, err) + return err } + var buf [MessageResponseSize]byte writer := bytes.NewBuffer(buf[:0]) @@ -228,7 +229,7 @@ func (peer *Peer) SendHandshakeResponse() error { peer.timersAnyAuthenticatedPacketSent() // TODO: allocation could be avoided - err = peer.SendBuffers([][]byte{junkedHeader}) + err = peer.SendAndCountBuffers([][]byte{junkedHeader}) if err != nil { peer.device.log.Errorf("%v - Failed to send handshake response: %v", peer, err) } @@ -251,11 +252,19 @@ func (device *Device) SendHandshakeCookie( return err } + junkedHeader, err := device.awg.CreateCookieReplyHeaderJunk() + if err != nil { + device.log.Errorf("%v - %v", device, err) + return err + } + var buf [MessageCookieReplySize]byte writer := bytes.NewBuffer(buf[:0]) binary.Write(writer, binary.LittleEndian, reply) + + junkedHeader = append(junkedHeader, writer.Bytes()...) // TODO: allocation could be avoided - device.net.bind.Send([][]byte{writer.Bytes()}, initiatingElem.endpoint) + device.net.bind.Send([][]byte{junkedHeader}, initiatingElem.endpoint) return nil } @@ -576,6 +585,14 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) { for _, elem := range elemsContainer.elems { if len(elem.packet) != MessageKeepaliveSize { dataSent = true + + junkedHeader, err := device.awg.CreateTransportHeaderJunk(len(elem.packet)) + if err != nil { + device.log.Errorf("%v - %v", device, err) + continue + } + + elem.packet = append(junkedHeader, elem.packet...) } bufs = append(bufs, elem.packet) } @@ -583,10 +600,11 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) { peer.timersAnyAuthenticatedPacketTraversal() peer.timersAnyAuthenticatedPacketSent() - err := peer.SendBuffers(bufs) + err := peer.SendAndCountBuffers(bufs) if dataSent { peer.timersDataSent() } + for _, elem := range elemsContainer.elems { device.PutMessageBuffer(elem.buffer) device.PutOutboundElement(elem) diff --git a/device/uapi.go b/device/uapi.go index 870bddc..e9f962a 100644 --- a/device/uapi.go +++ b/device/uapi.go @@ -18,6 +18,7 @@ import ( "sync" "time" + "github.com/amnezia-vpn/amneziawg-go/device/awg" "github.com/amnezia-vpn/amneziawg-go/ipc" ) @@ -97,33 +98,51 @@ func (device *Device) IpcGetOperation(w io.Writer) error { sendf("fwmark=%d", device.net.fwmark) } - if device.isAdvancedSecurityOn() { - if device.aSecCfg.junkPacketCount != 0 { - sendf("jc=%d", device.aSecCfg.junkPacketCount) + if device.isAWG() { + if device.awg.ASecCfg.JunkPacketCount != 0 { + sendf("jc=%d", device.awg.ASecCfg.JunkPacketCount) } - if device.aSecCfg.junkPacketMinSize != 0 { - sendf("jmin=%d", device.aSecCfg.junkPacketMinSize) + if device.awg.ASecCfg.JunkPacketMinSize != 0 { + sendf("jmin=%d", device.awg.ASecCfg.JunkPacketMinSize) } - if device.aSecCfg.junkPacketMaxSize != 0 { - sendf("jmax=%d", device.aSecCfg.junkPacketMaxSize) + if device.awg.ASecCfg.JunkPacketMaxSize != 0 { + sendf("jmax=%d", device.awg.ASecCfg.JunkPacketMaxSize) } - if device.aSecCfg.initPacketJunkSize != 0 { - sendf("s1=%d", device.aSecCfg.initPacketJunkSize) + if device.awg.ASecCfg.InitHeaderJunkSize != 0 { + sendf("s1=%d", device.awg.ASecCfg.InitHeaderJunkSize) } - if device.aSecCfg.responsePacketJunkSize != 0 { - sendf("s2=%d", device.aSecCfg.responsePacketJunkSize) + if device.awg.ASecCfg.ResponseHeaderJunkSize != 0 { + sendf("s2=%d", device.awg.ASecCfg.ResponseHeaderJunkSize) } - if device.aSecCfg.initPacketMagicHeader != 0 { - sendf("h1=%d", device.aSecCfg.initPacketMagicHeader) + if device.awg.ASecCfg.CookieReplyHeaderJunkSize != 0 { + sendf("s3=%d", device.awg.ASecCfg.CookieReplyHeaderJunkSize) } - if device.aSecCfg.responsePacketMagicHeader != 0 { - sendf("h2=%d", device.aSecCfg.responsePacketMagicHeader) + if device.awg.ASecCfg.TransportHeaderJunkSize != 0 { + sendf("s4=%d", device.awg.ASecCfg.TransportHeaderJunkSize) } - if device.aSecCfg.underloadPacketMagicHeader != 0 { - sendf("h3=%d", device.aSecCfg.underloadPacketMagicHeader) + if device.awg.ASecCfg.InitPacketMagicHeader != 0 { + sendf("h1=%d", device.awg.ASecCfg.InitPacketMagicHeader) } - if device.aSecCfg.transportPacketMagicHeader != 0 { - sendf("h4=%d", device.aSecCfg.transportPacketMagicHeader) + if device.awg.ASecCfg.ResponsePacketMagicHeader != 0 { + sendf("h2=%d", device.awg.ASecCfg.ResponsePacketMagicHeader) + } + if device.awg.ASecCfg.UnderloadPacketMagicHeader != 0 { + sendf("h3=%d", device.awg.ASecCfg.UnderloadPacketMagicHeader) + } + if device.awg.ASecCfg.TransportPacketMagicHeader != 0 { + sendf("h4=%d", device.awg.ASecCfg.TransportPacketMagicHeader) + } + + specialJunkIpcFields := device.awg.HandshakeHandler.SpecialJunk.IpcGetFields() + for _, field := range specialJunkIpcFields { + sendf("%s=%s", field.Key, field.Value) + } + controlledJunkIpcFields := device.awg.HandshakeHandler.ControlledJunk.IpcGetFields() + for _, field := range controlledJunkIpcFields { + sendf("%s=%s", field.Key, field.Value) + } + if device.awg.HandshakeHandler.ITimeout != 0 { + sendf("itime=%d", device.awg.HandshakeHandler.ITimeout/time.Second) } } @@ -180,13 +199,13 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) { peer := new(ipcSetPeer) deviceConfig := true - tempASecCfg := aSecCfgType{} + tempAwg := awg.Protocol{} scanner := bufio.NewScanner(r) for scanner.Scan() { line := scanner.Text() if line == "" { // Blank line means terminate operation. - err := device.handlePostConfig(&tempASecCfg) + err := device.handlePostConfig(&tempAwg) if err != nil { return err } @@ -217,7 +236,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) { var err error if deviceConfig { - err = device.handleDeviceLine(key, value, &tempASecCfg) + err = device.handleDeviceLine(key, value, &tempAwg) } else { err = device.handlePeerLine(peer, key, value) } @@ -225,7 +244,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) { return err } } - err = device.handlePostConfig(&tempASecCfg) + err = device.handlePostConfig(&tempAwg) if err != nil { return err } @@ -237,7 +256,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) { return nil } -func (device *Device) handleDeviceLine(key, value string, tempASecCfg *aSecCfgType) error { +func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol) error { switch key { case "private_key": var sk NoisePrivateKey @@ -278,7 +297,11 @@ func (device *Device) handleDeviceLine(key, value string, tempASecCfg *aSecCfgTy case "replace_peers": if value != "true" { - return ipcErrorf(ipc.IpcErrorInvalid, "failed to set replace_peers, invalid value: %v", value) + return ipcErrorf( + ipc.IpcErrorInvalid, + "failed to set replace_peers, invalid value: %v", + value, + ) } device.log.Verbosef("UAPI: Removing all peers") device.RemoveAllPeers() @@ -286,80 +309,138 @@ func (device *Device) handleDeviceLine(key, value string, tempASecCfg *aSecCfgTy case "jc": junkPacketCount, err := strconv.Atoi(value) if err != nil { - return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse junk_packet_count %w", err) + return ipcErrorf(ipc.IpcErrorInvalid, "parse junk_packet_count %w", err) } device.log.Verbosef("UAPI: Updating junk_packet_count") - tempASecCfg.junkPacketCount = junkPacketCount - tempASecCfg.isSet = true + tempAwg.ASecCfg.JunkPacketCount = junkPacketCount + tempAwg.ASecCfg.IsSet = true case "jmin": junkPacketMinSize, err := strconv.Atoi(value) if err != nil { - return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse junk_packet_min_size %w", err) + return ipcErrorf(ipc.IpcErrorInvalid, "parse junk_packet_min_size %w", err) } device.log.Verbosef("UAPI: Updating junk_packet_min_size") - tempASecCfg.junkPacketMinSize = junkPacketMinSize - tempASecCfg.isSet = true + tempAwg.ASecCfg.JunkPacketMinSize = junkPacketMinSize + tempAwg.ASecCfg.IsSet = true case "jmax": junkPacketMaxSize, err := strconv.Atoi(value) if err != nil { - return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse junk_packet_max_size %w", err) + return ipcErrorf(ipc.IpcErrorInvalid, "parse junk_packet_max_size %w", err) } device.log.Verbosef("UAPI: Updating junk_packet_max_size") - tempASecCfg.junkPacketMaxSize = junkPacketMaxSize - tempASecCfg.isSet = true + tempAwg.ASecCfg.JunkPacketMaxSize = junkPacketMaxSize + tempAwg.ASecCfg.IsSet = true case "s1": initPacketJunkSize, err := strconv.Atoi(value) if err != nil { - return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse init_packet_junk_size %w", err) + return ipcErrorf(ipc.IpcErrorInvalid, "parse init_packet_junk_size %w", err) } device.log.Verbosef("UAPI: Updating init_packet_junk_size") - tempASecCfg.initPacketJunkSize = initPacketJunkSize - tempASecCfg.isSet = true + tempAwg.ASecCfg.InitHeaderJunkSize = initPacketJunkSize + tempAwg.ASecCfg.IsSet = true case "s2": responsePacketJunkSize, err := strconv.Atoi(value) if err != nil { - return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse response_packet_junk_size %w", err) + return ipcErrorf(ipc.IpcErrorInvalid, "parse response_packet_junk_size %w", err) } device.log.Verbosef("UAPI: Updating response_packet_junk_size") - tempASecCfg.responsePacketJunkSize = responsePacketJunkSize - tempASecCfg.isSet = true + tempAwg.ASecCfg.ResponseHeaderJunkSize = responsePacketJunkSize + tempAwg.ASecCfg.IsSet = true + + case "s3": + cookieReplyPacketJunkSize, err := strconv.Atoi(value) + if err != nil { + return ipcErrorf(ipc.IpcErrorInvalid, "parse cookie_reply_packet_junk_size %w", err) + } + device.log.Verbosef("UAPI: Updating cookie_reply_packet_junk_size") + tempAwg.ASecCfg.CookieReplyHeaderJunkSize = cookieReplyPacketJunkSize + tempAwg.ASecCfg.IsSet = true + + case "s4": + transportPacketJunkSize, err := strconv.Atoi(value) + if err != nil { + return ipcErrorf(ipc.IpcErrorInvalid, "parse transport_packet_junk_size %w", err) + } + device.log.Verbosef("UAPI: Updating transport_packet_junk_size") + tempAwg.ASecCfg.TransportHeaderJunkSize = transportPacketJunkSize + tempAwg.ASecCfg.IsSet = true case "h1": initPacketMagicHeader, err := strconv.ParseUint(value, 10, 32) if err != nil { - return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse init_packet_magic_header %w", err) + return ipcErrorf(ipc.IpcErrorInvalid, "parse init_packet_magic_header %w", err) } - tempASecCfg.initPacketMagicHeader = uint32(initPacketMagicHeader) - tempASecCfg.isSet = true + tempAwg.ASecCfg.InitPacketMagicHeader = uint32(initPacketMagicHeader) + tempAwg.ASecCfg.IsSet = true case "h2": responsePacketMagicHeader, err := strconv.ParseUint(value, 10, 32) if err != nil { - return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse response_packet_magic_header %w", err) + return ipcErrorf(ipc.IpcErrorInvalid, "parse response_packet_magic_header %w", err) } - tempASecCfg.responsePacketMagicHeader = uint32(responsePacketMagicHeader) - tempASecCfg.isSet = true + tempAwg.ASecCfg.ResponsePacketMagicHeader = uint32(responsePacketMagicHeader) + tempAwg.ASecCfg.IsSet = true case "h3": underloadPacketMagicHeader, err := strconv.ParseUint(value, 10, 32) if err != nil { - return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse underload_packet_magic_header %w", err) + return ipcErrorf(ipc.IpcErrorInvalid, "parse underload_packet_magic_header %w", err) } - tempASecCfg.underloadPacketMagicHeader = uint32(underloadPacketMagicHeader) - tempASecCfg.isSet = true + tempAwg.ASecCfg.UnderloadPacketMagicHeader = uint32(underloadPacketMagicHeader) + tempAwg.ASecCfg.IsSet = true case "h4": transportPacketMagicHeader, err := strconv.ParseUint(value, 10, 32) if err != nil { - return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse transport_packet_magic_header %w", err) + return ipcErrorf(ipc.IpcErrorInvalid, "parse transport_packet_magic_header %w", err) + } + tempAwg.ASecCfg.TransportPacketMagicHeader = uint32(transportPacketMagicHeader) + tempAwg.ASecCfg.IsSet = true + case "i1", "i2", "i3", "i4", "i5": + if len(value) == 0 { + device.log.Verbosef("UAPI: received empty %s", key) + return nil } - tempASecCfg.transportPacketMagicHeader = uint32(transportPacketMagicHeader) - tempASecCfg.isSet = true + generators, err := awg.Parse(key, value) + if err != nil { + return ipcErrorf(ipc.IpcErrorInvalid, "invalid %s: %w", key, err) + } + device.log.Verbosef("UAPI: Updating %s", key) + tempAwg.HandshakeHandler.SpecialJunk.AppendGenerator(generators) + tempAwg.HandshakeHandler.IsSet = true + case "j1", "j2", "j3": + if len(value) == 0 { + device.log.Verbosef("UAPI: received empty %s", key) + return nil + } + + generators, err := awg.Parse(key, value) + if err != nil { + return ipcErrorf(ipc.IpcErrorInvalid, "invalid %s: %w", key, err) + } + device.log.Verbosef("UAPI: Updating %s", key) + + tempAwg.HandshakeHandler.ControlledJunk.AppendGenerator(generators) + tempAwg.HandshakeHandler.IsSet = true + case "itime": + if len(value) == 0 { + device.log.Verbosef("UAPI: received empty itime") + return nil + } + + itime, err := strconv.ParseInt(value, 10, 64) + if err != nil { + return ipcErrorf(ipc.IpcErrorInvalid, "parse itime %w", err) + } + device.log.Verbosef("UAPI: Updating itime") + + tempAwg.HandshakeHandler.ITimeout = time.Duration(itime) * time.Second + tempAwg.HandshakeHandler.IsSet = true default: return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI device key: %v", key) } @@ -432,7 +513,11 @@ func (device *Device) handlePeerLine( case "update_only": // allow disabling of creation if value != "true" { - return ipcErrorf(ipc.IpcErrorInvalid, "failed to set update only, invalid value: %v", value) + return ipcErrorf( + ipc.IpcErrorInvalid, + "failed to set update only, invalid value: %v", + value, + ) } if peer.created && !peer.dummy { device.RemovePeer(peer.handshake.remoteStatic) @@ -478,7 +563,11 @@ func (device *Device) handlePeerLine( secs, err := strconv.ParseUint(value, 10, 16) if err != nil { - return ipcErrorf(ipc.IpcErrorInvalid, "failed to set persistent keepalive interval: %w", err) + return ipcErrorf( + ipc.IpcErrorInvalid, + "failed to set persistent keepalive interval: %w", + err, + ) } old := peer.persistentKeepaliveInterval.Swap(uint32(secs)) @@ -489,7 +578,11 @@ func (device *Device) handlePeerLine( case "replace_allowed_ips": device.log.Verbosef("%v - UAPI: Removing all allowedips", peer.Peer) if value != "true" { - return ipcErrorf(ipc.IpcErrorInvalid, "failed to replace allowedips, invalid value: %v", value) + return ipcErrorf( + ipc.IpcErrorInvalid, + "failed to replace allowedips, invalid value: %v", + value, + ) } if peer.dummy { return nil @@ -568,7 +661,11 @@ func (device *Device) IpcHandle(socket net.Conn) { return } if nextByte != '\n' { - err = ipcErrorf(ipc.IpcErrorInvalid, "trailing character in UAPI get: %q", nextByte) + err = ipcErrorf( + ipc.IpcErrorInvalid, + "trailing character in UAPI get: %q", + nextByte, + ) break } err = device.IpcGetOperation(buffered.Writer) diff --git a/go.mod b/go.mod index 99569f3..5e5f34d 100644 --- a/go.mod +++ b/go.mod @@ -1,17 +1,23 @@ module github.com/amnezia-vpn/amneziawg-go -go 1.24 +go 1.24.4 require ( + github.com/stretchr/testify v1.10.0 + github.com/tevino/abool v1.2.0 github.com/tevino/abool/v2 v2.1.0 - golang.org/x/crypto v0.37.0 - golang.org/x/net v0.39.0 - golang.org/x/sys v0.32.0 + go.uber.org/atomic v1.11.0 + golang.org/x/crypto v0.39.0 + golang.org/x/net v0.41.0 + golang.org/x/sys v0.33.0 golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 - gvisor.dev/gvisor v0.0.0-20250503011706-39ed1f5ac29c + gvisor.dev/gvisor v0.0.0-20231202080848-1f7806d17489 ) require ( + github.com/davecgh/go-spew v1.1.1 // indirect github.com/google/btree v1.1.3 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect golang.org/x/time v0.9.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index b8ac0bd..6b8f36b 100644 --- a/go.sum +++ b/go.sum @@ -1,16 +1,40 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg= github.com/google/btree v1.1.3/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4= +github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/tevino/abool v1.2.0 h1:heAkClL8H6w+mK5md9dzsuohKeXHUpY7Vw0ZCKW+huA= +github.com/tevino/abool v1.2.0/go.mod h1:qc66Pna1RiIsPa7O4Egxxs9OqkuxDX55zznh9K07Tzg= github.com/tevino/abool/v2 v2.1.0 h1:7w+Vf9f/5gmKT4m4qkayb33/92M+Um45F2BkHOR+L/c= github.com/tevino/abool/v2 v2.1.0/go.mod h1:+Lmlqk6bHDWHqN1cbxqhwEAwMPXgc8I1SDEamtseuXY= -golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE= -golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc= -golang.org/x/net v0.39.0 h1:ZCu7HMWDxpXpaiKdhzIfaltL9Lp31x/3fCP11bc6/fY= -golang.org/x/net v0.39.0/go.mod h1:X7NRbYVEA+ewNkCNyJ513WmMdQ3BineSwVtN2zD/d+E= -golang.org/x/sys v0.32.0 h1:s77OFDvIQeibCmezSnk/q6iAfkdiQaJi4VzroCFrN20= -golang.org/x/sys v0.32.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= +go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= +golang.org/x/crypto v0.39.0 h1:SHs+kF4LP+f+p14esP5jAoDpHU8Gu/v9lFRK6IT5imM= +golang.org/x/crypto v0.39.0/go.mod h1:L+Xg3Wf6HoL4Bn4238Z6ft6KfEpN0tJGo53AAPC632U= +golang.org/x/mod v0.13.0 h1:I/DsJXRlw/8l/0c24sM9yb0T4z9liZTduXvdAWYiysY= +golang.org/x/mod v0.21.0 h1:vvrHzRwRfVKSiLrG+d4FMl/Qi4ukBCE6kZlTUkDYRT0= +golang.org/x/mod v0.21.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= +golang.org/x/net v0.41.0 h1:vBTly1HeNPEn3wtREYfy4GZ/NECgw2Cnl+nK6Nz3uvw= +golang.org/x/net v0.41.0/go.mod h1:B/K4NNqkfmg07DQYrbwvSluqCJOOXwUjeb/5lOisjbA= +golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= +golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/time v0.9.0 h1:EsRrnYcQiGH+5FfbgvV4AP7qEZstoyrHB0DzarOQ4ZY= golang.org/x/time v0.9.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI= -gvisor.dev/gvisor v0.0.0-20250503011706-39ed1f5ac29c h1:m/r7OM+Y2Ty1sgBQ7Qb27VgIMBW8ZZhT4gLnUyDIhzI= -gvisor.dev/gvisor v0.0.0-20250503011706-39ed1f5ac29c/go.mod h1:3r5CMtNQMKIvBlrmM9xWUNamjKBYPOWyXOjmg5Kts3g= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gvisor.dev/gvisor v0.0.0-20231202080848-1f7806d17489 h1:ze1vwAdliUAr68RQ5NtufWaXaOg8WUO2OACzEV+TNdE= +gvisor.dev/gvisor v0.0.0-20231202080848-1f7806d17489/go.mod h1:10sU+Uh5KKNv1+2x2A0Gvzt8FjD3ASIhorV3YsauXhk= +gvisor.dev/gvisor v0.0.0-20250428193742-2d800c3129d5 h1:sfK5nHuG7lRFZ2FdTT3RimOqWBg8IrVm+/Vko1FVOsk= +gvisor.dev/gvisor v0.0.0-20250428193742-2d800c3129d5/go.mod h1:3r5CMtNQMKIvBlrmM9xWUNamjKBYPOWyXOjmg5Kts3g= +gvisor.dev/gvisor v0.0.0-20250606233247-e3c4c4cad86f h1:zmc4cHEcCudRt2O8VsCW7nYLfAsbVY2i910/DAop1TM= +gvisor.dev/gvisor v0.0.0-20250606233247-e3c4c4cad86f/go.mod h1:3r5CMtNQMKIvBlrmM9xWUNamjKBYPOWyXOjmg5Kts3g=