mirror of
https://github.com/amnezia-vpn/amneziawg-go.git
synced 2025-08-02 01:42:54 +02:00
Merge 1896d9ba3f
into 1abd24b5b9
This commit is contained in:
commit
de1ca9582d
6 changed files with 21 additions and 11 deletions
|
@ -31,12 +31,12 @@ type SpecialHandshakeHandler struct {
|
||||||
IsSet bool
|
IsSet bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (handler *SpecialHandshakeHandler) Validate() error {
|
func (handler *SpecialHandshakeHandler) Validate(maxSegmentSize int) error {
|
||||||
var errs []error
|
var errs []error
|
||||||
if err := handler.SpecialJunk.Validate(); err != nil {
|
if err := handler.SpecialJunk.Validate(maxSegmentSize); err != nil {
|
||||||
errs = append(errs, err)
|
errs = append(errs, err)
|
||||||
}
|
}
|
||||||
if err := handler.ControlledJunk.Validate(); err != nil {
|
if err := handler.ControlledJunk.Validate(maxSegmentSize); err != nil {
|
||||||
errs = append(errs, err)
|
errs = append(errs, err)
|
||||||
}
|
}
|
||||||
return errors.Join(errs...)
|
return errors.Join(errs...)
|
||||||
|
|
|
@ -57,3 +57,7 @@ func (tg *TagJunkPacketGenerator) IpcGetFields() IpcFields {
|
||||||
Value: tg.tagValue,
|
Value: tg.tagValue,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (tg *TagJunkPacketGenerator) Size() int {
|
||||||
|
return tg.packetSize
|
||||||
|
}
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
package awg
|
package awg
|
||||||
|
|
||||||
import "fmt"
|
import (
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
type TagJunkPacketGenerators struct {
|
type TagJunkPacketGenerators struct {
|
||||||
tagGenerators []TagJunkPacketGenerator
|
tagGenerators []TagJunkPacketGenerator
|
||||||
|
@ -20,7 +22,7 @@ func (generators *TagJunkPacketGenerators) IsDefined() bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
// validate that packets were defined consecutively
|
// validate that packets were defined consecutively
|
||||||
func (generators *TagJunkPacketGenerators) Validate() error {
|
func (generators *TagJunkPacketGenerators) Validate(maxSegmentSize int) error {
|
||||||
seen := make([]bool, len(generators.tagGenerators))
|
seen := make([]bool, len(generators.tagGenerators))
|
||||||
for _, generator := range generators.tagGenerators {
|
for _, generator := range generators.tagGenerators {
|
||||||
index, err := generator.nameIndex()
|
index, err := generator.nameIndex()
|
||||||
|
@ -32,6 +34,10 @@ func (generators *TagJunkPacketGenerators) Validate() error {
|
||||||
} else {
|
} else {
|
||||||
seen[index-1] = true
|
seen[index-1] = true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if generator.Size() > maxSegmentSize {
|
||||||
|
return fmt.Errorf("junk packet %s must not exceed %d bytes", generator.name, maxSegmentSize)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, found := range seen {
|
for _, found := range seen {
|
||||||
|
|
|
@ -91,7 +91,7 @@ func TestTagJunkGeneratorHandlerValidate(t *testing.T) {
|
||||||
generators.AppendGenerator(gen)
|
generators.AppendGenerator(gen)
|
||||||
}
|
}
|
||||||
|
|
||||||
err := generators.Validate()
|
err := generators.Validate(1500)
|
||||||
if tt.wantErr {
|
if tt.wantErr {
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.Contains(t, err.Error(), tt.errMsg)
|
require.Contains(t, err.Error(), tt.errMsg)
|
||||||
|
|
|
@ -819,7 +819,7 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
if tempAwg.HandshakeHandler.IsSet {
|
if tempAwg.HandshakeHandler.IsSet {
|
||||||
if err := tempAwg.HandshakeHandler.Validate(); err != nil {
|
if err := tempAwg.HandshakeHandler.Validate(MaxSegmentSize); err != nil {
|
||||||
errs = append(errs, ipcErrorf(
|
errs = append(errs, ipcErrorf(
|
||||||
ipc.IpcErrorInvalid, "handshake handler validate: %w", err))
|
ipc.IpcErrorInvalid, "handshake handler validate: %w", err))
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -406,12 +406,12 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
generators, err := awg.Parse(key, value)
|
generator, err := awg.Parse(key, value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ipcErrorf(ipc.IpcErrorInvalid, "invalid %s: %w", key, err)
|
return ipcErrorf(ipc.IpcErrorInvalid, "invalid %s: %w", key, err)
|
||||||
}
|
}
|
||||||
device.log.Verbosef("UAPI: Updating %s", key)
|
device.log.Verbosef("UAPI: Updating %s", key)
|
||||||
tempAwg.HandshakeHandler.SpecialJunk.AppendGenerator(generators)
|
tempAwg.HandshakeHandler.SpecialJunk.AppendGenerator(generator)
|
||||||
tempAwg.HandshakeHandler.IsSet = true
|
tempAwg.HandshakeHandler.IsSet = true
|
||||||
case "j1", "j2", "j3":
|
case "j1", "j2", "j3":
|
||||||
if len(value) == 0 {
|
if len(value) == 0 {
|
||||||
|
@ -419,13 +419,13 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
generators, err := awg.Parse(key, value)
|
generator, err := awg.Parse(key, value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ipcErrorf(ipc.IpcErrorInvalid, "invalid %s: %w", key, err)
|
return ipcErrorf(ipc.IpcErrorInvalid, "invalid %s: %w", key, err)
|
||||||
}
|
}
|
||||||
device.log.Verbosef("UAPI: Updating %s", key)
|
device.log.Verbosef("UAPI: Updating %s", key)
|
||||||
|
|
||||||
tempAwg.HandshakeHandler.ControlledJunk.AppendGenerator(generators)
|
tempAwg.HandshakeHandler.ControlledJunk.AppendGenerator(generator)
|
||||||
tempAwg.HandshakeHandler.IsSet = true
|
tempAwg.HandshakeHandler.IsSet = true
|
||||||
case "itime":
|
case "itime":
|
||||||
if len(value) == 0 {
|
if len(value) == 0 {
|
||||||
|
|
Loading…
Add table
Reference in a new issue