mirror of
https://github.com/amnezia-vpn/amneziawg-go.git
synced 2025-04-16 22:16:55 +02:00
Send msg type to code & add default msg types
This commit is contained in:
parent
32470fa04e
commit
553060f804
5 changed files with 40 additions and 31 deletions
|
@ -36,13 +36,14 @@ func (l *Lua) Close() {
|
|||
l.state.Close()
|
||||
}
|
||||
|
||||
func (l *Lua) Generate(data []byte, counter int64) ([]byte, error) {
|
||||
func (l *Lua) Generate(msgType int64, data []byte, counter int64) ([]byte, error) {
|
||||
l.state.GetGlobal("d_gen")
|
||||
|
||||
l.state.PushInteger(msgType)
|
||||
l.state.PushBytes(data)
|
||||
l.state.PushInteger(counter)
|
||||
|
||||
if err := l.state.Call(2, 1); err != nil {
|
||||
if err := l.state.Call(3, 1); err != nil {
|
||||
return nil, fmt.Errorf("Error calling Lua function: %v\n", err)
|
||||
}
|
||||
|
||||
|
|
|
@ -7,17 +7,17 @@ import (
|
|||
func newLua() *Lua {
|
||||
lua, _ := NewLua(LuaParams{
|
||||
/*
|
||||
function d_gen(data, counter)
|
||||
local header = "header"
|
||||
return counter .. header .. data
|
||||
end
|
||||
function d_gen(msg_type, data, counter)
|
||||
local header = "header"
|
||||
return counter .. header .. data
|
||||
end
|
||||
|
||||
function d_parse(data)
|
||||
local header = "10header"
|
||||
return string.sub(data, #header+1)
|
||||
end
|
||||
function d_parse(data)
|
||||
local header = "10header"
|
||||
return string.sub(data, #header+1)
|
||||
end
|
||||
*/
|
||||
Base64LuaCode: "ZnVuY3Rpb24gZF9nZW4oZGF0YSwgY291bnRlcikKCWxvY2FsIGhlYWRlciA9ICJoZWFkZXIiCglyZXR1cm4gY291bnRlciAuLiBoZWFkZXIgLi4gZGF0YQplbmQKCmZ1bmN0aW9uIGRfcGFyc2UoZGF0YSkKCWxvY2FsIGhlYWRlciA9ICIxMGhlYWRlciIKCXJldHVybiBzdHJpbmcuc3ViKGRhdGEsICNoZWFkZXIrMSkKZW5kCg==",
|
||||
Base64LuaCode: "CmZ1bmN0aW9uIGRfZ2VuKG1zZ190eXBlLCBkYXRhLCBjb3VudGVyKQoJbG9jYWwgaGVhZGVyID0gImhlYWRlciIKCXJldHVybiBjb3VudGVyIC4uIGhlYWRlciAuLiBkYXRhCmVuZAoKZnVuY3Rpb24gZF9wYXJzZShkYXRhKQoJbG9jYWwgaGVhZGVyID0gIjEwaGVhZGVyIgoJcmV0dXJuIHN0cmluZy5zdWIoZGF0YSwgI2hlYWRlcisxKQplbmQK",
|
||||
})
|
||||
return lua
|
||||
}
|
||||
|
@ -26,7 +26,7 @@ func TestLua_Generate(t *testing.T) {
|
|||
t.Run("", func(t *testing.T) {
|
||||
l := newLua()
|
||||
defer l.Close()
|
||||
got, err := l.Generate([]byte("test"), 10)
|
||||
got, err := l.Generate(1, []byte("test"), 10)
|
||||
if err != nil {
|
||||
t.Errorf(
|
||||
"Lua.Generate() error = %v, wantErr %v",
|
||||
|
|
|
@ -433,6 +433,7 @@ func (device *Device) Close() {
|
|||
|
||||
if device.luaAdapter != nil {
|
||||
device.luaAdapter.Close()
|
||||
device.luaAdapter = nil
|
||||
}
|
||||
device.log.Verbosef("Device closed")
|
||||
close(device.closed)
|
||||
|
@ -585,10 +586,10 @@ func (device *Device) isAdvancedSecurityOn() bool {
|
|||
|
||||
func (device *Device) resetProtocol() {
|
||||
// restore default message type values
|
||||
MessageInitiationType = 1
|
||||
MessageResponseType = 2
|
||||
MessageCookieReplyType = 3
|
||||
MessageTransportType = 4
|
||||
MessageInitiationType = DefaultMessageInitiationType
|
||||
MessageResponseType = DefaultMessageResponseType
|
||||
MessageCookieReplyType = DefaultMessageCookieReplyType
|
||||
MessageTransportType = DefaultMessageTransportType
|
||||
}
|
||||
|
||||
func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
|
||||
|
@ -721,7 +722,7 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
|
|||
MessageInitiationType = device.aSecCfg.initPacketMagicHeader
|
||||
} else {
|
||||
device.log.Verbosef("UAPI: Using default init type")
|
||||
MessageInitiationType = 1
|
||||
MessageInitiationType = DefaultMessageInitiationType
|
||||
}
|
||||
|
||||
if tempASecCfg.responsePacketMagicHeader > 4 {
|
||||
|
@ -731,7 +732,7 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
|
|||
MessageResponseType = device.aSecCfg.responsePacketMagicHeader
|
||||
} else {
|
||||
device.log.Verbosef("UAPI: Using default response type")
|
||||
MessageResponseType = 2
|
||||
MessageResponseType = DefaultMessageResponseType
|
||||
}
|
||||
|
||||
if tempASecCfg.underloadPacketMagicHeader > 4 {
|
||||
|
@ -741,7 +742,7 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
|
|||
MessageCookieReplyType = device.aSecCfg.underloadPacketMagicHeader
|
||||
} else {
|
||||
device.log.Verbosef("UAPI: Using default underload type")
|
||||
MessageCookieReplyType = 3
|
||||
MessageCookieReplyType = DefaultMessageCookieReplyType
|
||||
}
|
||||
|
||||
if tempASecCfg.transportPacketMagicHeader > 4 {
|
||||
|
@ -751,7 +752,7 @@ func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
|
|||
MessageTransportType = device.aSecCfg.transportPacketMagicHeader
|
||||
} else {
|
||||
device.log.Verbosef("UAPI: Using default transport type")
|
||||
MessageTransportType = 4
|
||||
MessageTransportType = DefaultMessageTransportType
|
||||
}
|
||||
|
||||
isSameMap := map[uint32]bool{}
|
||||
|
|
|
@ -52,11 +52,18 @@ const (
|
|||
WGLabelCookie = "cookie--"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultMessageInitiationType uint32 = 1
|
||||
DefaultMessageResponseType uint32 = 2
|
||||
DefaultMessageCookieReplyType uint32 = 3
|
||||
DefaultMessageTransportType uint32 = 4
|
||||
)
|
||||
|
||||
var (
|
||||
MessageInitiationType uint32 = 1
|
||||
MessageResponseType uint32 = 2
|
||||
MessageCookieReplyType uint32 = 3
|
||||
MessageTransportType uint32 = 4
|
||||
MessageInitiationType uint32 = DefaultMessageInitiationType
|
||||
MessageResponseType uint32 = DefaultMessageResponseType
|
||||
MessageCookieReplyType uint32 = DefaultMessageCookieReplyType
|
||||
MessageTransportType uint32 = DefaultMessageTransportType
|
||||
)
|
||||
|
||||
const (
|
||||
|
|
|
@ -175,7 +175,7 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
|
|||
peer.cookieGenerator.AddMacs(packet)
|
||||
junkedHeader = append(junkedHeader, packet...)
|
||||
|
||||
if junkedHeader, err = peer.device.codecPacket(junkedHeader); err != nil {
|
||||
if junkedHeader, err = peer.device.codecPacket(DefaultMessageInitiationType, junkedHeader); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -237,7 +237,7 @@ func (peer *Peer) SendHandshakeResponse() error {
|
|||
peer.cookieGenerator.AddMacs(packet)
|
||||
junkedHeader = append(junkedHeader, packet...)
|
||||
|
||||
if junkedHeader, err = peer.device.codecPacket(junkedHeader); err != nil {
|
||||
if junkedHeader, err = peer.device.codecPacket(DefaultMessageResponseType, junkedHeader); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -286,7 +286,7 @@ func (device *Device) SendHandshakeCookie(
|
|||
writer := bytes.NewBuffer(buf[:0])
|
||||
binary.Write(writer, binary.LittleEndian, reply)
|
||||
packet := writer.Bytes()
|
||||
if packet, err = device.codecPacket(packet); err != nil {
|
||||
if packet, err = device.codecPacket(DefaultMessageCookieReplyType, packet); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -547,10 +547,10 @@ func calculatePaddingSize(packetSize, mtu int) int {
|
|||
return paddedSize - lastUnit
|
||||
}
|
||||
|
||||
func (device *Device) codecPacket(packet []byte) ([]byte, error) {
|
||||
func (device *Device) codecPacket(msgType uint32, packet []byte) ([]byte, error) {
|
||||
if device.luaAdapter != nil {
|
||||
var err error
|
||||
packet, err = device.luaAdapter.Generate(packet, device.packetCounter.Add(1))
|
||||
packet, err = device.luaAdapter.Generate(int64(msgType),packet, device.packetCounter.Add(1))
|
||||
if err != nil {
|
||||
device.log.Errorf("%v - Failed to run codec generate: %v", device, err)
|
||||
return nil, err
|
||||
|
@ -603,9 +603,9 @@ func (device *Device) RoutineEncryption(id int) {
|
|||
elem.packet,
|
||||
nil,
|
||||
)
|
||||
// TODO: check
|
||||
|
||||
var err error
|
||||
if elem.packet, err = device.codecPacket(elem.packet); err != nil {
|
||||
if elem.packet, err = device.codecPacket(DefaultMessageTransportType, elem.packet); err != nil {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue