feat: add s3, s4

This commit is contained in:
Mark Puha 2025-07-02 19:59:53 +02:00
parent 5680162c72
commit 05fbf0feb0
10 changed files with 235 additions and 108 deletions

View file

@ -1,11 +1,27 @@
package awg
import (
"bytes"
"sync"
"github.com/tevino/abool"
)
type aSecCfgType struct {
IsSet bool
JunkPacketCount int
JunkPacketMinSize int
JunkPacketMaxSize int
InitHeaderJunkSize int
ResponseHeaderJunkSize int
CookieReplyHeaderJunkSize int
TransportHeaderJunkSize int
InitPacketMagicHeader uint32
ResponsePacketMagicHeader uint32
UnderloadPacketMagicHeader uint32
TransportPacketMagicHeader uint32
}
type Protocol struct {
IsASecOn abool.AtomicBool
// TODO: revision the need of the mutex
@ -16,15 +32,36 @@ type Protocol struct {
HandshakeHandler SpecialHandshakeHandler
}
type aSecCfgType struct {
IsSet bool
JunkPacketCount int
JunkPacketMinSize int
JunkPacketMaxSize int
InitPacketJunkSize int
ResponsePacketJunkSize int
InitPacketMagicHeader uint32
ResponsePacketMagicHeader uint32
UnderloadPacketMagicHeader uint32
TransportPacketMagicHeader uint32
func (protocol *Protocol) CreateInitHeaderJunk() ([]byte, error) {
return protocol.createHeaderJunk(protocol.ASecCfg.InitHeaderJunkSize)
}
func (protocol *Protocol) CreateResponseHeaderJunk() ([]byte, error) {
return protocol.createHeaderJunk(protocol.ASecCfg.ResponseHeaderJunkSize)
}
func (protocol *Protocol) CreateCookieReplyHeaderJunk() ([]byte, error) {
return protocol.createHeaderJunk(protocol.ASecCfg.CookieReplyHeaderJunkSize)
}
func (protocol *Protocol) CreateTransportHeaderJunk() ([]byte, error) {
return protocol.createHeaderJunk(protocol.ASecCfg.TransportHeaderJunkSize)
}
func (protocol *Protocol) createHeaderJunk(junkSize int) ([]byte, error) {
var junk []byte
protocol.ASecMux.RLock()
if junkSize != 0 {
buf := make([]byte, 0, junkSize)
writer := bytes.NewBuffer(buf[:0])
err := protocol.JunkCreator.AppendJunk(writer, junkSize)
if err != nil {
protocol.ASecMux.RUnlock()
return nil, err
}
junk = writer.Bytes()
}
protocol.ASecMux.RUnlock()
return junk, nil
}

View file

@ -12,6 +12,7 @@ type junkCreator struct {
cha8Rand *v2.ChaCha8
}
// TODO: refactor param to only pass the junk related params
func NewJunkCreator(aSecCfg aSecCfgType) (junkCreator, error) {
buf := make([]byte, 32)
_, err := crand.Read(buf)

View file

@ -12,8 +12,8 @@ func setUpJunkCreator(t *testing.T) (junkCreator, error) {
JunkPacketCount: 5,
JunkPacketMinSize: 500,
JunkPacketMaxSize: 1000,
InitPacketJunkSize: 30,
ResponsePacketJunkSize: 40,
InitHeaderJunkSize: 30,
ResponseHeaderJunkSize: 40,
InitPacketMagicHeader: 123456,
ResponsePacketMagicHeader: 67543,
UnderloadPacketMagicHeader: 32345,

View file

@ -646,38 +646,111 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
isASecOn = true
}
if MessageInitiationSize+tempAwg.ASecCfg.InitPacketJunkSize >= MaxSegmentSize {
newInitSize := MessageInitiationSize + tempAwg.ASecCfg.InitHeaderJunkSize
if newInitSize >= MaxSegmentSize {
errs = append(errs, ipcErrorf(
ipc.IpcErrorInvalid,
`init header size(148) + junkSize:%d; should be smaller than maxSegmentSize: %d`,
tempAwg.ASecCfg.InitPacketJunkSize,
tempAwg.ASecCfg.InitHeaderJunkSize,
MaxSegmentSize,
),
)
} else {
device.awg.ASecCfg.InitPacketJunkSize = tempAwg.ASecCfg.InitPacketJunkSize
device.awg.ASecCfg.InitHeaderJunkSize = tempAwg.ASecCfg.InitHeaderJunkSize
}
if tempAwg.ASecCfg.InitPacketJunkSize != 0 {
if tempAwg.ASecCfg.InitHeaderJunkSize != 0 {
isASecOn = true
}
if MessageResponseSize+tempAwg.ASecCfg.ResponsePacketJunkSize >= MaxSegmentSize {
newResponseSize := MessageResponseSize + tempAwg.ASecCfg.ResponseHeaderJunkSize
if newResponseSize >= MaxSegmentSize {
errs = append(errs, ipcErrorf(
ipc.IpcErrorInvalid,
`response header size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d`,
tempAwg.ASecCfg.ResponsePacketJunkSize,
tempAwg.ASecCfg.ResponseHeaderJunkSize,
MaxSegmentSize,
),
)
} else {
device.awg.ASecCfg.ResponsePacketJunkSize = tempAwg.ASecCfg.ResponsePacketJunkSize
device.awg.ASecCfg.ResponseHeaderJunkSize = tempAwg.ASecCfg.ResponseHeaderJunkSize
}
if tempAwg.ASecCfg.ResponsePacketJunkSize != 0 {
if tempAwg.ASecCfg.ResponseHeaderJunkSize != 0 {
isASecOn = true
}
newCookieSize := MessageCookieReplySize + tempAwg.ASecCfg.CookieReplyHeaderJunkSize
if newCookieSize >= MaxSegmentSize {
errs = append(errs, ipcErrorf(
ipc.IpcErrorInvalid,
`cookie reply size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d`,
tempAwg.ASecCfg.CookieReplyHeaderJunkSize,
MaxSegmentSize,
),
)
} else {
device.awg.ASecCfg.CookieReplyHeaderJunkSize = tempAwg.ASecCfg.CookieReplyHeaderJunkSize
}
if tempAwg.ASecCfg.CookieReplyHeaderJunkSize != 0 {
isASecOn = true
}
newTransportSize := MessageTransportSize + tempAwg.ASecCfg.TransportHeaderJunkSize
if newTransportSize >= MaxSegmentSize {
errs = append(errs, ipcErrorf(
ipc.IpcErrorInvalid,
`transport size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d`,
tempAwg.ASecCfg.TransportHeaderJunkSize,
MaxSegmentSize,
),
)
} else {
device.awg.ASecCfg.TransportHeaderJunkSize = tempAwg.ASecCfg.TransportHeaderJunkSize
}
if tempAwg.ASecCfg.TransportHeaderJunkSize != 0 {
isASecOn = true
}
isSameSizeMap := map[int]struct{}{
newInitSize: {},
newResponseSize: {},
newCookieSize: {},
newTransportSize: {},
}
if len(isSameSizeMap) != 4 {
errs = append(errs, ipcErrorf(
ipc.IpcErrorInvalid,
`new sizes should differ; init: %d; response: %d; cookie: %d; trans: %d`,
newInitSize,
newResponseSize,
newCookieSize,
newTransportSize,
),
)
} else {
packetSizeToMsgType = map[int]uint32{
newInitSize: MessageInitiationType,
newResponseSize: MessageResponseType,
newCookieSize: MessageCookieReplyType,
newTransportSize: MessageTransportType,
}
msgTypeToJunkSize = map[uint32]int{
MessageInitiationType: device.awg.ASecCfg.InitHeaderJunkSize,
MessageResponseType: device.awg.ASecCfg.ResponseHeaderJunkSize,
MessageCookieReplyType: device.awg.ASecCfg.CookieReplyHeaderJunkSize,
MessageTransportType: device.awg.ASecCfg.TransportHeaderJunkSize,
}
}
if tempAwg.ASecCfg.InitPacketMagicHeader > 4 {
isASecOn = true
device.log.Verbosef("UAPI: Updating init_packet_magic_header")
@ -718,7 +791,7 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
MessageTransportType = DefaultMessageTransportType
}
isSameMap := map[uint32]struct{}{
isSameHeaderMap := map[uint32]struct{}{
MessageInitiationType: {},
MessageResponseType: {},
MessageCookieReplyType: {},
@ -726,7 +799,7 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
}
// size will be different if same values
if len(isSameMap) != 4 {
if len(isSameHeaderMap) != 4 {
errs = append(errs, ipcErrorf(
ipc.IpcErrorInvalid,
`magic headers should differ; got: init:%d; recv:%d; unde:%d; tran:%d`,
@ -738,33 +811,6 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
)
}
newInitSize := MessageInitiationSize + device.awg.ASecCfg.InitPacketJunkSize
newResponseSize := MessageResponseSize + device.awg.ASecCfg.ResponsePacketJunkSize
if newInitSize == newResponseSize {
errs = append(errs, ipcErrorf(
ipc.IpcErrorInvalid,
`new init size:%d; and new response size:%d; should differ`,
newInitSize,
newResponseSize,
),
)
} else {
packetSizeToMsgType = map[int]uint32{
newInitSize: MessageInitiationType,
newResponseSize: MessageResponseType,
MessageCookieReplySize: MessageCookieReplyType,
MessageTransportSize: MessageTransportType,
}
msgTypeToJunkSize = map[uint32]int{
MessageInitiationType: device.awg.ASecCfg.InitPacketJunkSize,
MessageResponseType: device.awg.ASecCfg.ResponsePacketJunkSize,
MessageCookieReplyType: 0,
MessageTransportType: 0,
}
}
device.awg.IsASecOn.SetTo(isASecOn)
var err error
device.awg.JunkCreator, err = awg.NewJunkCreator(device.awg.ASecCfg)

View file

@ -17,6 +17,7 @@ import (
"os/signal"
"runtime"
"runtime/pprof"
"strconv"
"sync"
"testing"
"time"
@ -129,6 +130,7 @@ func (pair *testPair) Send(
tb testing.TB,
ping SendDirection,
done chan struct{},
optTransportJunk ...int,
) {
tb.Helper()
p0, p1 := pair[0], pair[1]
@ -136,6 +138,12 @@ func (pair *testPair) Send(
// pong is the new ping
p0, p1 = p1, p0
}
transportJunk := 0
if len(optTransportJunk) > 0 {
transportJunk = optTransportJunk[0]
}
msg := tuntest.Ping(p0.ip, p1.ip)
p1.tun.Outbound <- msg
timer := time.NewTimer(6 * time.Second)
@ -143,7 +151,10 @@ func (pair *testPair) Send(
var err error
select {
case msgRecv := <-p0.tun.Inbound:
if !bytes.Equal(msg, msgRecv) {
fmt.Printf("%x\n", msg)
fmt.Printf("%x\n", msgRecv[transportJunk:])
fmt.Printf("%x\n", msgRecv)
if !bytes.Equal(msg, msgRecv[transportJunk:]) {
err = fmt.Errorf("%s did not transit correctly", ping)
}
case <-timer.C:
@ -226,22 +237,27 @@ func TestTwoDevicePing(t *testing.T) {
// Run test with -race=false to avoid the race for setting the default msgTypes 2 times
func TestAWGDevicePing(t *testing.T) {
goroutineLeakCheck(t)
transportJunk := 5
pair := genTestPair(t, true,
"jc", "5",
"jmin", "500",
"jmax", "1000",
"s1", "30",
"s2", "40",
"h1", "123456",
"h2", "67543",
"h4", "32345",
"h3", "123123",
// "s1", "30",
// "s2", "40",
"s3", "50",
"s4", strconv.Itoa(transportJunk),
// "h1", "123456",
// "h2", "67543",
// "h3", "123123",
// "h4", "32345",
)
t.Run("ping 1.0.0.1", func(t *testing.T) {
pair.Send(t, Ping, nil)
pair.Send(t, Ping, nil, transportJunk)
})
t.Run("ping 1.0.0.2", func(t *testing.T) {
pair.Send(t, Pong, nil)
pair.Send(t, Pong, nil, transportJunk)
})
}

View file

@ -82,9 +82,10 @@ const (
MessageTransportOffsetContent = 16
)
var packetSizeToMsgType map[int]uint32
var msgTypeToJunkSize map[uint32]int
var (
packetSizeToMsgType map[int]uint32
msgTypeToJunkSize map[uint32]int
)
/* Type is an 8-bit field, followed by 3 nul bytes,
* by marshalling the messages in little-endian byteorder

View file

@ -114,6 +114,16 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
return peer, nil
}
func (peer *Peer) SendAndCountBuffers(buffers [][]byte) error {
err := peer.SendBuffers(buffers)
if err == nil {
awg.PacketCounter.Add(uint64(len(buffers)))
return nil
}
return err
}
func (peer *Peer) SendBuffers(buffers [][]byte) error {
peer.device.net.RLock()
defer peer.device.net.RUnlock()
@ -145,16 +155,6 @@ func (peer *Peer) SendBuffers(buffers [][]byte) error {
return err
}
func (peer *Peer) SendAndCountBuffers(buffers [][]byte) error {
err := peer.SendBuffers(buffers)
if err == nil {
awg.PacketCounter.Add(uint64(len(buffers)))
return nil
}
return err
}
func (peer *Peer) String() string {
// The awful goo that follows is identical to:
//

View file

@ -157,12 +157,15 @@ func (device *Device) RoutineReceiveIncoming(
msgType = binary.LittleEndian.Uint32(packet[:4])
}
} else {
msgType = binary.LittleEndian.Uint32(packet[:4])
transportJunkSize := device.awg.ASecCfg.TransportHeaderJunkSize
msgType = binary.LittleEndian.Uint32(packet[transportJunkSize : transportJunkSize+4])
if msgType != MessageTransportType {
// probably a junk packet
device.log.Verbosef("aSec: Received message with unknown type: %d", msgType)
continue
}
packet = packet[transportJunkSize:]
}
} else {
msgType = binary.LittleEndian.Uint32(packet[:4])

View file

@ -124,9 +124,9 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
return err
}
var sendBuffer [][]byte
// so only packet processed for cookie generation
var junkedHeader []byte
if peer.device.version >= VersionAwg {
var junks [][]byte
if peer.device.version == VersionAwgSpecialHandshake {
@ -163,19 +163,11 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
}
}
peer.device.awg.ASecMux.RLock()
if peer.device.awg.ASecCfg.InitPacketJunkSize != 0 {
buf := make([]byte, 0, peer.device.awg.ASecCfg.InitPacketJunkSize)
writer := bytes.NewBuffer(buf[:0])
err = peer.device.awg.JunkCreator.AppendJunk(writer, peer.device.awg.ASecCfg.InitPacketJunkSize)
if err != nil {
peer.device.log.Errorf("%v - %v", peer, err)
peer.device.awg.ASecMux.RUnlock()
return err
}
junkedHeader = writer.Bytes()
junkedHeader, err = peer.device.awg.CreateInitHeaderJunk()
if err != nil {
peer.device.log.Errorf("%v - %v", peer, err)
return err
}
peer.device.awg.ASecMux.RUnlock()
}
var buf [MessageInitiationSize]byte
@ -211,22 +203,13 @@ func (peer *Peer) SendHandshakeResponse() error {
peer.device.log.Errorf("%v - Failed to create response message: %v", peer, err)
return err
}
var junkedHeader []byte
if peer.device.isAWG() {
peer.device.awg.ASecMux.RLock()
if peer.device.awg.ASecCfg.ResponsePacketJunkSize != 0 {
buf := make([]byte, 0, peer.device.awg.ASecCfg.ResponsePacketJunkSize)
writer := bytes.NewBuffer(buf[:0])
err = peer.device.awg.JunkCreator.AppendJunk(writer, peer.device.awg.ASecCfg.ResponsePacketJunkSize)
if err != nil {
peer.device.awg.ASecMux.RUnlock()
peer.device.log.Errorf("%v - %v", peer, err)
return err
}
junkedHeader = writer.Bytes()
}
peer.device.awg.ASecMux.RUnlock()
junkedHeader, err := peer.device.awg.CreateResponseHeaderJunk()
if err != nil {
peer.device.log.Errorf("%v - %v", peer, err)
return err
}
var buf [MessageResponseSize]byte
writer := bytes.NewBuffer(buf[:0])
@ -269,11 +252,19 @@ func (device *Device) SendHandshakeCookie(
return err
}
junkedHeader, err := device.awg.CreateCookieReplyHeaderJunk()
if err != nil {
device.log.Errorf("%v - %v", device, err)
return err
}
var buf [MessageCookieReplySize]byte
writer := bytes.NewBuffer(buf[:0])
binary.Write(writer, binary.LittleEndian, reply)
junkedHeader = append(junkedHeader, writer.Bytes()...)
// TODO: allocation could be avoided
device.net.bind.Send([][]byte{writer.Bytes()}, initiatingElem.endpoint)
device.net.bind.Send([][]byte{junkedHeader}, initiatingElem.endpoint)
return nil
}
@ -593,6 +584,13 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) {
for _, elem := range elemsContainer.elems {
if len(elem.packet) != MessageKeepaliveSize {
dataSent = true
junkedHeader, err := device.awg.CreateTransportHeaderJunk()
if err != nil {
device.log.Errorf("%v - %v", device, err)
continue
}
elem.packet = append(junkedHeader, elem.packet...)
}
bufs = append(bufs, elem.packet)
}
@ -604,6 +602,7 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) {
if dataSent {
peer.timersDataSent()
}
for _, elem := range elemsContainer.elems {
device.PutMessageBuffer(elem.buffer)
device.PutOutboundElement(elem)

View file

@ -108,11 +108,17 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
if device.awg.ASecCfg.JunkPacketMaxSize != 0 {
sendf("jmax=%d", device.awg.ASecCfg.JunkPacketMaxSize)
}
if device.awg.ASecCfg.InitPacketJunkSize != 0 {
sendf("s1=%d", device.awg.ASecCfg.InitPacketJunkSize)
if device.awg.ASecCfg.InitHeaderJunkSize != 0 {
sendf("s1=%d", device.awg.ASecCfg.InitHeaderJunkSize)
}
if device.awg.ASecCfg.ResponsePacketJunkSize != 0 {
sendf("s2=%d", device.awg.ASecCfg.ResponsePacketJunkSize)
if device.awg.ASecCfg.ResponseHeaderJunkSize != 0 {
sendf("s2=%d", device.awg.ASecCfg.ResponseHeaderJunkSize)
}
if device.awg.ASecCfg.CookieReplyHeaderJunkSize != 0 {
sendf("s3=%d", device.awg.ASecCfg.CookieReplyHeaderJunkSize)
}
if device.awg.ASecCfg.TransportHeaderJunkSize != 0 {
sendf("s4=%d", device.awg.ASecCfg.TransportHeaderJunkSize)
}
if device.awg.ASecCfg.InitPacketMagicHeader != 0 {
sendf("h1=%d", device.awg.ASecCfg.InitPacketMagicHeader)
@ -333,7 +339,7 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol)
return ipcErrorf(ipc.IpcErrorInvalid, "parse init_packet_junk_size %w", err)
}
device.log.Verbosef("UAPI: Updating init_packet_junk_size")
tempAwg.ASecCfg.InitPacketJunkSize = initPacketJunkSize
tempAwg.ASecCfg.InitHeaderJunkSize = initPacketJunkSize
tempAwg.ASecCfg.IsSet = true
case "s2":
@ -342,7 +348,25 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol)
return ipcErrorf(ipc.IpcErrorInvalid, "parse response_packet_junk_size %w", err)
}
device.log.Verbosef("UAPI: Updating response_packet_junk_size")
tempAwg.ASecCfg.ResponsePacketJunkSize = responsePacketJunkSize
tempAwg.ASecCfg.ResponseHeaderJunkSize = responsePacketJunkSize
tempAwg.ASecCfg.IsSet = true
case "s3":
cookieReplyPacketJunkSize, err := strconv.Atoi(value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "parse cookie_reply_packet_junk_size %w", err)
}
device.log.Verbosef("UAPI: Updating cookie_reply_packet_junk_size")
tempAwg.ASecCfg.CookieReplyHeaderJunkSize = cookieReplyPacketJunkSize
tempAwg.ASecCfg.IsSet = true
case "s4":
transportPacketJunkSize, err := strconv.Atoi(value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "parse transport_packet_junk_size %w", err)
}
device.log.Verbosef("UAPI: Updating transport_packet_junk_size")
tempAwg.ASecCfg.TransportHeaderJunkSize = transportPacketJunkSize
tempAwg.ASecCfg.IsSet = true
case "h1":