add codec generation/parsing

This commit is contained in:
Mark Puha 2025-02-08 19:00:00 +01:00
parent 3015f3ea20
commit 32470fa04e
3 changed files with 48 additions and 6 deletions

View file

@ -11,6 +11,7 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/amnezia-vpn/amneziawg-go/adapter"
"github.com/amnezia-vpn/amneziawg-go/conn" "github.com/amnezia-vpn/amneziawg-go/conn"
"github.com/amnezia-vpn/amneziawg-go/ipc" "github.com/amnezia-vpn/amneziawg-go/ipc"
"github.com/amnezia-vpn/amneziawg-go/ratelimiter" "github.com/amnezia-vpn/amneziawg-go/ratelimiter"
@ -92,11 +93,13 @@ type Device struct {
closed chan struct{} closed chan struct{}
log *Logger log *Logger
isASecOn abool.AtomicBool isASecOn abool.AtomicBool
aSecMux sync.RWMutex aSecMux sync.RWMutex
aSecCfg aSecCfgType aSecCfg aSecCfgType
junkCreator junkCreator junkCreator junkCreator
luaAdapter *adapter.Lua
packetCounter atomic.Int64
} }
type aSecCfgType struct { type aSecCfgType struct {
@ -428,6 +431,9 @@ func (device *Device) Close() {
device.resetProtocol() device.resetProtocol()
if device.luaAdapter != nil {
device.luaAdapter.Close()
}
device.log.Verbosef("Device closed") device.log.Verbosef("Device closed")
close(device.closed) close(device.closed)
} }

View file

@ -137,8 +137,14 @@ func (device *Device) RoutineReceiveIncoming(
} }
// check size of packet // check size of packet
packet := bufsArrs[i][:size] packet := bufsArrs[i][:size]
if device.luaAdapter != nil {
packet, err = device.luaAdapter.Parse(packet)
if err != nil {
device.log.Verbosef("Couldn't parse message; reason: %v", err)
continue
}
}
var msgType uint32 var msgType uint32
if device.isAdvancedSecurityOn() { if device.isAdvancedSecurityOn() {
if assumedMsgType, ok := packetSizeToMsgType[size]; ok { if assumedMsgType, ok := packetSizeToMsgType[size]; ok {

View file

@ -175,6 +175,10 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
peer.cookieGenerator.AddMacs(packet) peer.cookieGenerator.AddMacs(packet)
junkedHeader = append(junkedHeader, packet...) junkedHeader = append(junkedHeader, packet...)
if junkedHeader, err = peer.device.codecPacket(junkedHeader); err != nil {
return err
}
peer.timersAnyAuthenticatedPacketTraversal() peer.timersAnyAuthenticatedPacketTraversal()
peer.timersAnyAuthenticatedPacketSent() peer.timersAnyAuthenticatedPacketSent()
@ -233,6 +237,10 @@ func (peer *Peer) SendHandshakeResponse() error {
peer.cookieGenerator.AddMacs(packet) peer.cookieGenerator.AddMacs(packet)
junkedHeader = append(junkedHeader, packet...) junkedHeader = append(junkedHeader, packet...)
if junkedHeader, err = peer.device.codecPacket(junkedHeader); err != nil {
return err
}
err = peer.BeginSymmetricSession() err = peer.BeginSymmetricSession()
if err != nil { if err != nil {
peer.device.log.Errorf("%v - Failed to derive keypair: %v", peer, err) peer.device.log.Errorf("%v - Failed to derive keypair: %v", peer, err)
@ -277,8 +285,13 @@ func (device *Device) SendHandshakeCookie(
var buf [MessageCookieReplySize]byte var buf [MessageCookieReplySize]byte
writer := bytes.NewBuffer(buf[:0]) writer := bytes.NewBuffer(buf[:0])
binary.Write(writer, binary.LittleEndian, reply) binary.Write(writer, binary.LittleEndian, reply)
packet := writer.Bytes()
if packet, err = device.codecPacket(packet); err != nil {
return err
}
// TODO: allocation could be avoided // TODO: allocation could be avoided
device.net.bind.Send([][]byte{writer.Bytes()}, initiatingElem.endpoint) device.net.bind.Send([][]byte{packet}, initiatingElem.endpoint)
return nil return nil
} }
@ -534,6 +547,18 @@ func calculatePaddingSize(packetSize, mtu int) int {
return paddedSize - lastUnit return paddedSize - lastUnit
} }
func (device *Device) codecPacket(packet []byte) ([]byte, error) {
if device.luaAdapter != nil {
var err error
packet, err = device.luaAdapter.Generate(packet, device.packetCounter.Add(1))
if err != nil {
device.log.Errorf("%v - Failed to run codec generate: %v", device, err)
return nil, err
}
}
return packet, nil
}
/* Encrypts the elements in the queue /* Encrypts the elements in the queue
* and marks them for sequential consumption (by releasing the mutex) * and marks them for sequential consumption (by releasing the mutex)
* *
@ -578,6 +603,11 @@ func (device *Device) RoutineEncryption(id int) {
elem.packet, elem.packet,
nil, nil,
) )
// TODO: check
var err error
if elem.packet, err = device.codecPacket(elem.packet); err != nil {
continue
}
} }
elemsContainer.Unlock() elemsContainer.Unlock()
} }