From 553060f804a4429356037143a3dacd71d2e74a19 Mon Sep 17 00:00:00 2001
From: Mark Puha
Date: Sun, 9 Feb 2025 09:21:01 +0100
Subject: [PATCH] Send msg type to code & add default msg types
---
adapter/lua.go | 5 +++--
adapter/lua_test.go | 20 ++++++++++----------
device/device.go | 17 +++++++++--------
device/noise-protocol.go | 15 +++++++++++----
device/send.go | 14 +++++++-------
5 files changed, 40 insertions(+), 31 deletions(-)
diff --git a/adapter/lua.go b/adapter/lua.go
index 47ce7b2..e44ca70 100644
--- a/adapter/lua.go
+++ b/adapter/lua.go
@@ -36,13 +36,14 @@ func (l *Lua) Close() {
l.state.Close()
}
-func (l *Lua) Generate(data []byte, counter int64) ([]byte, error) {
+func (l *Lua) Generate(msgType int64, data []byte, counter int64) ([]byte, error) {
l.state.GetGlobal("d_gen")
+ l.state.PushInteger(msgType)
l.state.PushBytes(data)
l.state.PushInteger(counter)
- if err := l.state.Call(2, 1); err != nil {
+ if err := l.state.Call(3, 1); err != nil {
return nil, fmt.Errorf("Error calling Lua function: %v\n", err)
}
diff --git a/adapter/lua_test.go b/adapter/lua_test.go
index 08d5420..828b555 100644
--- a/adapter/lua_test.go
+++ b/adapter/lua_test.go
@@ -7,17 +7,17 @@ import (
func newLua() *Lua {
lua, _ := NewLua(LuaParams{
/*
- function d_gen(data, counter)
- local header = "header"
- return counter .. header .. data
- end
+ function d_gen(msg_type, data, counter)
+ local header = "header"
+ return counter .. header .. data
+ end
- function d_parse(data)
- local header = "10header"
- return string.sub(data, #header+1)
- end
+ function d_parse(data)
+ local header = "10header"
+ return string.sub(data, #header+1)
+ end
*/
- Base64LuaCode: "ZnVuY3Rpb24gZF9nZW4oZGF0YSwgY291bnRlcikKCWxvY2FsIGhlYWRlciA9ICJoZWFkZXIiCglyZXR1cm4gY291bnRlciAuLiBoZWFkZXIgLi4gZGF0YQplbmQKCmZ1bmN0aW9uIGRfcGFyc2UoZGF0YSkKCWxvY2FsIGhlYWRlciA9ICIxMGhlYWRlciIKCXJldHVybiBzdHJpbmcuc3ViKGRhdGEsICNoZWFkZXIrMSkKZW5kCg==",
+ Base64LuaCode: "CmZ1bmN0aW9uIGRfZ2VuKG1zZ190eXBlLCBkYXRhLCBjb3VudGVyKQoJbG9jYWwgaGVhZGVyID0gImhlYWRlciIKCXJldHVybiBjb3VudGVyIC4uIGhlYWRlciAuLiBkYXRhCmVuZAoKZnVuY3Rpb24gZF9wYXJzZShkYXRhKQoJbG9jYWwgaGVhZGVyID0gIjEwaGVhZGVyIgoJcmV0dXJuIHN0cmluZy5zdWIoZGF0YSwgI2hlYWRlcisxKQplbmQK",
})
return lua
}
@@ -26,7 +26,7 @@ func TestLua_Generate(t *testing.T) {
t.Run("", func(t *testing.T) {
l := newLua()
defer l.Close()
- got, err := l.Generate([]byte("test"), 10)
+ got, err := l.Generate(1, []byte("test"), 10)
if err != nil {
t.Errorf(
"Lua.Generate() error = %v, wantErr %v",
diff --git a/device/device.go b/device/device.go
index 8178667..aaddcbc 100644
--- a/device/device.go
+++ b/device/device.go
@@ -433,6 +433,7 @@ func (device *Device) Close() {
if device.luaAdapter != nil {
device.luaAdapter.Close()
+ device.luaAdapter = nil
}
device.log.Verbosef("Device closed")
close(device.closed)
@@ -585,10 +586,10 @@ func (device *Device) isAdvancedSecurityOn() bool {
func (device *Device) resetProtocol() {
// restore default message type values
- MessageInitiationType = 1
- MessageResponseType = 2
- MessageCookieReplyType = 3
- MessageTransportType = 4
+ MessageInitiationType = DefaultMessageInitiationType
+ MessageResponseType = DefaultMessageResponseType
+ MessageCookieReplyType = DefaultMessageCookieReplyType
+ MessageTransportType = DefaultMessageTransportType
}
func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
@@ -721,7 +722,7 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
MessageInitiationType = device.aSecCfg.initPacketMagicHeader
} else {
device.log.Verbosef("UAPI: Using default init type")
- MessageInitiationType = 1
+ MessageInitiationType = DefaultMessageInitiationType
}
if tempASecCfg.responsePacketMagicHeader > 4 {
@@ -731,7 +732,7 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
MessageResponseType = device.aSecCfg.responsePacketMagicHeader
} else {
device.log.Verbosef("UAPI: Using default response type")
- MessageResponseType = 2
+ MessageResponseType = DefaultMessageResponseType
}
if tempASecCfg.underloadPacketMagicHeader > 4 {
@@ -741,7 +742,7 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
MessageCookieReplyType = device.aSecCfg.underloadPacketMagicHeader
} else {
device.log.Verbosef("UAPI: Using default underload type")
- MessageCookieReplyType = 3
+ MessageCookieReplyType = DefaultMessageCookieReplyType
}
if tempASecCfg.transportPacketMagicHeader > 4 {
@@ -751,7 +752,7 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
MessageTransportType = device.aSecCfg.transportPacketMagicHeader
} else {
device.log.Verbosef("UAPI: Using default transport type")
- MessageTransportType = 4
+ MessageTransportType = DefaultMessageTransportType
}
isSameMap := map[uint32]bool{}
diff --git a/device/noise-protocol.go b/device/noise-protocol.go
index 1289249..d818238 100644
--- a/device/noise-protocol.go
+++ b/device/noise-protocol.go
@@ -52,11 +52,18 @@ const (
WGLabelCookie = "cookie--"
)
+const (
+ DefaultMessageInitiationType uint32 = 1
+ DefaultMessageResponseType uint32 = 2
+ DefaultMessageCookieReplyType uint32 = 3
+ DefaultMessageTransportType uint32 = 4
+)
+
var (
- MessageInitiationType uint32 = 1
- MessageResponseType uint32 = 2
- MessageCookieReplyType uint32 = 3
- MessageTransportType uint32 = 4
+ MessageInitiationType uint32 = DefaultMessageInitiationType
+ MessageResponseType uint32 = DefaultMessageResponseType
+ MessageCookieReplyType uint32 = DefaultMessageCookieReplyType
+ MessageTransportType uint32 = DefaultMessageTransportType
)
const (
diff --git a/device/send.go b/device/send.go
index c0b9ad4..ac05313 100644
--- a/device/send.go
+++ b/device/send.go
@@ -175,7 +175,7 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
peer.cookieGenerator.AddMacs(packet)
junkedHeader = append(junkedHeader, packet...)
- if junkedHeader, err = peer.device.codecPacket(junkedHeader); err != nil {
+ if junkedHeader, err = peer.device.codecPacket(DefaultMessageInitiationType, junkedHeader); err != nil {
return err
}
@@ -237,7 +237,7 @@ func (peer *Peer) SendHandshakeResponse() error {
peer.cookieGenerator.AddMacs(packet)
junkedHeader = append(junkedHeader, packet...)
- if junkedHeader, err = peer.device.codecPacket(junkedHeader); err != nil {
+ if junkedHeader, err = peer.device.codecPacket(DefaultMessageResponseType, junkedHeader); err != nil {
return err
}
@@ -286,7 +286,7 @@ func (device *Device) SendHandshakeCookie(
writer := bytes.NewBuffer(buf[:0])
binary.Write(writer, binary.LittleEndian, reply)
packet := writer.Bytes()
- if packet, err = device.codecPacket(packet); err != nil {
+ if packet, err = device.codecPacket(DefaultMessageCookieReplyType, packet); err != nil {
return err
}
@@ -547,10 +547,10 @@ func calculatePaddingSize(packetSize, mtu int) int {
return paddedSize - lastUnit
}
-func (device *Device) codecPacket(packet []byte) ([]byte, error) {
+func (device *Device) codecPacket(msgType uint32, packet []byte) ([]byte, error) {
if device.luaAdapter != nil {
var err error
- packet, err = device.luaAdapter.Generate(packet, device.packetCounter.Add(1))
+ packet, err = device.luaAdapter.Generate(int64(msgType),packet, device.packetCounter.Add(1))
if err != nil {
device.log.Errorf("%v - Failed to run codec generate: %v", device, err)
return nil, err
@@ -603,9 +603,9 @@ func (device *Device) RoutineEncryption(id int) {
elem.packet,
nil,
)
- // TODO: check
+
var err error
- if elem.packet, err = device.codecPacket(elem.packet); err != nil {
+ if elem.packet, err = device.codecPacket(DefaultMessageTransportType, elem.packet); err != nil {
continue
}
}