add remaning guarding

Signed-off-by: Mark Puha <marko10@inf.elte.hu>
This commit is contained in:
Mark Puha 2023-09-12 06:14:19 +02:00
parent 6035a275d2
commit b9d4759cef
2 changed files with 28 additions and 11 deletions

View file

@ -199,10 +199,12 @@ func (device *Device) CreateMessageInitiation(
handshake.mixHash(handshake.remoteStatic[:]) handshake.mixHash(handshake.remoteStatic[:])
device.aSecMux.RLock()
msg := MessageInitiation{ msg := MessageInitiation{
Type: MessageInitiationType, Type: MessageInitiationType,
Ephemeral: handshake.localEphemeral.publicKey(), Ephemeral: handshake.localEphemeral.publicKey(),
} }
device.aSecMux.RUnlock()
handshake.mixKey(msg.Ephemeral[:]) handshake.mixKey(msg.Ephemeral[:])
handshake.mixHash(msg.Ephemeral[:]) handshake.mixHash(msg.Ephemeral[:])
@ -261,9 +263,12 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
chainKey [blake2s.Size]byte chainKey [blake2s.Size]byte
) )
device.aSecMux.RLock()
if msg.Type != MessageInitiationType { if msg.Type != MessageInitiationType {
device.aSecMux.RUnlock()
return nil return nil
} }
device.aSecMux.RUnlock()
device.staticIdentity.RLock() device.staticIdentity.RLock()
defer device.staticIdentity.RUnlock() defer device.staticIdentity.RUnlock()
@ -392,7 +397,9 @@ func (device *Device) CreateMessageResponse(
} }
var msg MessageResponse var msg MessageResponse
device.aSecMux.RLock()
msg.Type = MessageResponseType msg.Type = MessageResponseType
device.aSecMux.RUnlock()
msg.Sender = handshake.localIndex msg.Sender = handshake.localIndex
msg.Receiver = handshake.remoteIndex msg.Receiver = handshake.remoteIndex
@ -442,9 +449,12 @@ func (device *Device) CreateMessageResponse(
} }
func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
device.aSecMux.RLock()
if msg.Type != MessageResponseType { if msg.Type != MessageResponseType {
device.aSecMux.RUnlock()
return nil return nil
} }
device.aSecMux.RUnlock()
// lookup handshake by receiver // lookup handshake by receiver

View file

@ -128,8 +128,10 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
// so only packet processed for cookie generation // so only packet processed for cookie generation
var junkedHeader []byte var junkedHeader []byte
if peer.device.isAdvancedSecurityOn() { if peer.device.isAdvancedSecurityOn() {
peer.device.aSecMux.RLock()
err = peer.sendJunkPackets() err = peer.sendJunkPackets()
if err != nil { if err != nil {
peer.device.aSecMux.RUnlock()
peer.device.log.Errorf("%v - %v", peer, err) peer.device.log.Errorf("%v - %v", peer, err)
return err return err
} }
@ -138,11 +140,13 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
writer := bytes.NewBuffer(buf[:0]) writer := bytes.NewBuffer(buf[:0])
err = appendJunk(writer, peer.device.aSecCfg.initPacketJunkSize) err = appendJunk(writer, peer.device.aSecCfg.initPacketJunkSize)
if err != nil { if err != nil {
peer.device.aSecMux.RUnlock()
peer.device.log.Errorf("%v - %v", peer, err) peer.device.log.Errorf("%v - %v", peer, err)
return err return err
} }
junkedHeader = writer.Bytes() junkedHeader = writer.Bytes()
} }
peer.device.aSecMux.RUnlock()
} }
var buf [MessageInitiationSize]byte var buf [MessageInitiationSize]byte
writer := bytes.NewBuffer(buf[:0]) writer := bytes.NewBuffer(buf[:0])
@ -184,17 +188,20 @@ func (peer *Peer) SendHandshakeResponse() error {
return err return err
} }
var junkedHeader []byte var junkedHeader []byte
if peer.device.isAdvancedSecurityOn() && if peer.device.isAdvancedSecurityOn() {
peer.device.aSecCfg.responsePacketJunkSize != 0 { peer.device.aSecMux.RLock()
if peer.device.aSecCfg.responsePacketJunkSize != 0 {
buf := make([]byte, 0, peer.device.aSecCfg.responsePacketJunkSize) buf := make([]byte, 0, peer.device.aSecCfg.responsePacketJunkSize)
writer := bytes.NewBuffer(buf[:0]) writer := bytes.NewBuffer(buf[:0])
err = appendJunk(writer, peer.device.aSecCfg.responsePacketJunkSize) err = appendJunk(writer, peer.device.aSecCfg.responsePacketJunkSize)
if err != nil { if err != nil {
peer.device.log.Errorf("%v - %v", peer, err) peer.device.aSecMux.RUnlock()
return err peer.device.log.Errorf("%v - %v", peer, err)
} return err
junkedHeader = writer.Bytes() }
junkedHeader = writer.Bytes()
}
peer.device.aSecMux.RUnlock()
} }
var buf [MessageResponseSize]byte var buf [MessageResponseSize]byte
writer := bytes.NewBuffer(buf[:0]) writer := bytes.NewBuffer(buf[:0])