From 553060f804a4429356037143a3dacd71d2e74a19 Mon Sep 17 00:00:00 2001 From: Mark Puha Date: Sun, 9 Feb 2025 09:21:01 +0100 Subject: [PATCH] Send msg type to code & add default msg types --- adapter/lua.go | 5 +++-- adapter/lua_test.go | 20 ++++++++++---------- device/device.go | 17 +++++++++-------- device/noise-protocol.go | 15 +++++++++++---- device/send.go | 14 +++++++------- 5 files changed, 40 insertions(+), 31 deletions(-) diff --git a/adapter/lua.go b/adapter/lua.go index 47ce7b2..e44ca70 100644 --- a/adapter/lua.go +++ b/adapter/lua.go @@ -36,13 +36,14 @@ func (l *Lua) Close() { l.state.Close() } -func (l *Lua) Generate(data []byte, counter int64) ([]byte, error) { +func (l *Lua) Generate(msgType int64, data []byte, counter int64) ([]byte, error) { l.state.GetGlobal("d_gen") + l.state.PushInteger(msgType) l.state.PushBytes(data) l.state.PushInteger(counter) - if err := l.state.Call(2, 1); err != nil { + if err := l.state.Call(3, 1); err != nil { return nil, fmt.Errorf("Error calling Lua function: %v\n", err) } diff --git a/adapter/lua_test.go b/adapter/lua_test.go index 08d5420..828b555 100644 --- a/adapter/lua_test.go +++ b/adapter/lua_test.go @@ -7,17 +7,17 @@ import ( func newLua() *Lua { lua, _ := NewLua(LuaParams{ /* - function d_gen(data, counter) - local header = "header" - return counter .. header .. data - end + function d_gen(msg_type, data, counter) + local header = "header" + return counter .. header .. data + end - function d_parse(data) - local header = "10header" - return string.sub(data, #header+1) - end + function d_parse(data) + local header = "10header" + return string.sub(data, #header+1) + end */ - Base64LuaCode: "ZnVuY3Rpb24gZF9nZW4oZGF0YSwgY291bnRlcikKCWxvY2FsIGhlYWRlciA9ICJoZWFkZXIiCglyZXR1cm4gY291bnRlciAuLiBoZWFkZXIgLi4gZGF0YQplbmQKCmZ1bmN0aW9uIGRfcGFyc2UoZGF0YSkKCWxvY2FsIGhlYWRlciA9ICIxMGhlYWRlciIKCXJldHVybiBzdHJpbmcuc3ViKGRhdGEsICNoZWFkZXIrMSkKZW5kCg==", + Base64LuaCode: "CmZ1bmN0aW9uIGRfZ2VuKG1zZ190eXBlLCBkYXRhLCBjb3VudGVyKQoJbG9jYWwgaGVhZGVyID0gImhlYWRlciIKCXJldHVybiBjb3VudGVyIC4uIGhlYWRlciAuLiBkYXRhCmVuZAoKZnVuY3Rpb24gZF9wYXJzZShkYXRhKQoJbG9jYWwgaGVhZGVyID0gIjEwaGVhZGVyIgoJcmV0dXJuIHN0cmluZy5zdWIoZGF0YSwgI2hlYWRlcisxKQplbmQK", }) return lua } @@ -26,7 +26,7 @@ func TestLua_Generate(t *testing.T) { t.Run("", func(t *testing.T) { l := newLua() defer l.Close() - got, err := l.Generate([]byte("test"), 10) + got, err := l.Generate(1, []byte("test"), 10) if err != nil { t.Errorf( "Lua.Generate() error = %v, wantErr %v", diff --git a/device/device.go b/device/device.go index 8178667..aaddcbc 100644 --- a/device/device.go +++ b/device/device.go @@ -433,6 +433,7 @@ func (device *Device) Close() { if device.luaAdapter != nil { device.luaAdapter.Close() + device.luaAdapter = nil } device.log.Verbosef("Device closed") close(device.closed) @@ -585,10 +586,10 @@ func (device *Device) isAdvancedSecurityOn() bool { func (device *Device) resetProtocol() { // restore default message type values - MessageInitiationType = 1 - MessageResponseType = 2 - MessageCookieReplyType = 3 - MessageTransportType = 4 + MessageInitiationType = DefaultMessageInitiationType + MessageResponseType = DefaultMessageResponseType + MessageCookieReplyType = DefaultMessageCookieReplyType + MessageTransportType = DefaultMessageTransportType } func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) { @@ -721,7 +722,7 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) { MessageInitiationType = device.aSecCfg.initPacketMagicHeader } else { device.log.Verbosef("UAPI: Using default init type") - MessageInitiationType = 1 + MessageInitiationType = DefaultMessageInitiationType } if tempASecCfg.responsePacketMagicHeader > 4 { @@ -731,7 +732,7 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) { MessageResponseType = device.aSecCfg.responsePacketMagicHeader } else { device.log.Verbosef("UAPI: Using default response type") - MessageResponseType = 2 + MessageResponseType = DefaultMessageResponseType } if tempASecCfg.underloadPacketMagicHeader > 4 { @@ -741,7 +742,7 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) { MessageCookieReplyType = device.aSecCfg.underloadPacketMagicHeader } else { device.log.Verbosef("UAPI: Using default underload type") - MessageCookieReplyType = 3 + MessageCookieReplyType = DefaultMessageCookieReplyType } if tempASecCfg.transportPacketMagicHeader > 4 { @@ -751,7 +752,7 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) { MessageTransportType = device.aSecCfg.transportPacketMagicHeader } else { device.log.Verbosef("UAPI: Using default transport type") - MessageTransportType = 4 + MessageTransportType = DefaultMessageTransportType } isSameMap := map[uint32]bool{} diff --git a/device/noise-protocol.go b/device/noise-protocol.go index 1289249..d818238 100644 --- a/device/noise-protocol.go +++ b/device/noise-protocol.go @@ -52,11 +52,18 @@ const ( WGLabelCookie = "cookie--" ) +const ( + DefaultMessageInitiationType uint32 = 1 + DefaultMessageResponseType uint32 = 2 + DefaultMessageCookieReplyType uint32 = 3 + DefaultMessageTransportType uint32 = 4 +) + var ( - MessageInitiationType uint32 = 1 - MessageResponseType uint32 = 2 - MessageCookieReplyType uint32 = 3 - MessageTransportType uint32 = 4 + MessageInitiationType uint32 = DefaultMessageInitiationType + MessageResponseType uint32 = DefaultMessageResponseType + MessageCookieReplyType uint32 = DefaultMessageCookieReplyType + MessageTransportType uint32 = DefaultMessageTransportType ) const ( diff --git a/device/send.go b/device/send.go index c0b9ad4..ac05313 100644 --- a/device/send.go +++ b/device/send.go @@ -175,7 +175,7 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { peer.cookieGenerator.AddMacs(packet) junkedHeader = append(junkedHeader, packet...) - if junkedHeader, err = peer.device.codecPacket(junkedHeader); err != nil { + if junkedHeader, err = peer.device.codecPacket(DefaultMessageInitiationType, junkedHeader); err != nil { return err } @@ -237,7 +237,7 @@ func (peer *Peer) SendHandshakeResponse() error { peer.cookieGenerator.AddMacs(packet) junkedHeader = append(junkedHeader, packet...) - if junkedHeader, err = peer.device.codecPacket(junkedHeader); err != nil { + if junkedHeader, err = peer.device.codecPacket(DefaultMessageResponseType, junkedHeader); err != nil { return err } @@ -286,7 +286,7 @@ func (device *Device) SendHandshakeCookie( writer := bytes.NewBuffer(buf[:0]) binary.Write(writer, binary.LittleEndian, reply) packet := writer.Bytes() - if packet, err = device.codecPacket(packet); err != nil { + if packet, err = device.codecPacket(DefaultMessageCookieReplyType, packet); err != nil { return err } @@ -547,10 +547,10 @@ func calculatePaddingSize(packetSize, mtu int) int { return paddedSize - lastUnit } -func (device *Device) codecPacket(packet []byte) ([]byte, error) { +func (device *Device) codecPacket(msgType uint32, packet []byte) ([]byte, error) { if device.luaAdapter != nil { var err error - packet, err = device.luaAdapter.Generate(packet, device.packetCounter.Add(1)) + packet, err = device.luaAdapter.Generate(int64(msgType),packet, device.packetCounter.Add(1)) if err != nil { device.log.Errorf("%v - Failed to run codec generate: %v", device, err) return nil, err @@ -603,9 +603,9 @@ func (device *Device) RoutineEncryption(id int) { elem.packet, nil, ) - // TODO: check + var err error - if elem.packet, err = device.codecPacket(elem.packet); err != nil { + if elem.packet, err = device.codecPacket(DefaultMessageTransportType, elem.packet); err != nil { continue } }