From fb3fa4f9158458654281129f44f354a65741aef3 Mon Sep 17 00:00:00 2001
From: Mathias Hall-Andersen <mathias@hall-andersen.dk>
Date: Thu, 27 Jul 2017 23:45:37 +0200
Subject: [PATCH] Improved timer code

---
 src/constants.go      |   1 +
 src/noise_protocol.go |   3 +-
 src/peer.go           |  20 +--
 src/receive.go        |  36 +++---
 src/send.go           |  76 ++++++-----
 src/timers.go         | 291 ++++++++++++++++++++++++------------------
 6 files changed, 240 insertions(+), 187 deletions(-)

diff --git a/src/constants.go b/src/constants.go
index 6b0d414..09d33d8 100644
--- a/src/constants.go
+++ b/src/constants.go
@@ -20,6 +20,7 @@ const (
 
 const (
 	RekeyAfterTimeReceiving = RekeyAfterTime - KeepaliveTimeout - RekeyTimeout
+	NewHandshakeTime        = KeepaliveTimeout + RekeyTimeout // upon failure to acknowledge transport message
 )
 
 /* Implementation specific constants */
diff --git a/src/noise_protocol.go b/src/noise_protocol.go
index 5fe6fb2..e2ff573 100644
--- a/src/noise_protocol.go
+++ b/src/noise_protocol.go
@@ -37,6 +37,7 @@ const (
 	MessageCookieReplySize     = 64
 	MessageTransportHeaderSize = 16
 	MessageTransportSize       = MessageTransportHeaderSize + poly1305.TagSize // size of empty transport
+	MessageKeepaliveSize       = MessageTransportSize
 )
 
 const (
@@ -253,8 +254,6 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
 		}
 		hash = mixHash(hash, msg.Timestamp[:])
 
-		// TODO: check for flood attack
-
 		// check for replay attack
 
 		return timestamp.After(handshake.lastTimestamp)
diff --git a/src/peer.go b/src/peer.go
index 8eea929..9136959 100644
--- a/src/peer.go
+++ b/src/peer.go
@@ -40,21 +40,22 @@ type Peer struct {
 		stop               chan struct{} // (size 0) : close to stop all goroutines for peer
 	}
 	timer struct {
-		/* Both keep-alive timers acts as one (see timers.go)
-		 * They are kept seperate to simplify the implementation.
-		 */
 		keepalivePersistent *time.Timer // set for persistent keepalives
 		keepalivePassive    *time.Timer // set upon recieving messages
-		zeroAllKeys         *time.Timer // zero all key material after RejectAfterTime*3
+		newHandshake        *time.Timer // begin a new handshake (after Keepalive + RekeyTimeout)
+		zeroAllKeys         *time.Timer // zero all key material (after RejectAfterTime*3)
+
+		pendingKeepalivePassive bool
+		pendingNewHandshake     bool
+		pendingZeroAllKeys      bool
+
+		needAnotherKeepalive bool
 	}
 	queue struct {
 		nonce    chan *QueueOutboundElement // nonce / pre-handshake queue
 		outbound chan *QueueOutboundElement // sequential ordering of work
 		inbound  chan *QueueInboundElement  // sequential ordering of work
 	}
-	flags struct {
-		keepaliveWaiting int32
-	}
 	mac MACStatePeer
 }
 
@@ -68,12 +69,11 @@ func (device *Device) NewPeer(pk NoisePublicKey) *Peer {
 	peer.mac.Init(pk)
 	peer.device = device
 
-	peer.timer.keepalivePassive = NewStoppedTimer()
 	peer.timer.keepalivePersistent = NewStoppedTimer()
+	peer.timer.keepalivePassive = NewStoppedTimer()
+	peer.timer.newHandshake = NewStoppedTimer()
 	peer.timer.zeroAllKeys = NewStoppedTimer()
 
-	peer.flags.keepaliveWaiting = AtomicFalse
-
 	// assign id for debugging
 
 	device.mutex.Lock()
diff --git a/src/receive.go b/src/receive.go
index d97ca41..c74211b 100644
--- a/src/receive.go
+++ b/src/receive.go
@@ -288,6 +288,7 @@ func (device *Device) RoutineHandshake() {
 	logDebug := device.log.Debug
 	logDebug.Println("Routine, handshake routine, started for device")
 
+	var temp [256]byte
 	var elem QueueHandshakeElement
 
 	for {
@@ -363,6 +364,7 @@ func (device *Device) RoutineHandshake() {
 					)
 					return
 				}
+				peer.TimerPacketReceived()
 
 				// update endpoint
 
@@ -378,17 +380,19 @@ func (device *Device) RoutineHandshake() {
 					return
 				}
 
+				peer.TimerEphemeralKeyCreated()
+
 				logDebug.Println("Creating response message for", peer.String())
 
-				outElem := device.NewOutboundElement()
-				writer := bytes.NewBuffer(outElem.buffer[:0])
+				writer := bytes.NewBuffer(temp[:0])
 				binary.Write(writer, binary.LittleEndian, response)
-				outElem.packet = writer.Bytes()
-				peer.mac.AddMacs(outElem.packet)
-				addToOutboundQueue(peer.queue.outbound, outElem)
+				packet := writer.Bytes()
+				peer.mac.AddMacs(packet)
 
-				// create new keypair
+				// send response
 
+				peer.SendBuffer(packet)
+				peer.TimerPacketSent()
 				peer.NewKeyPair()
 
 			case MessageResponseType:
@@ -418,12 +422,11 @@ func (device *Device) RoutineHandshake() {
 					)
 					return
 				}
-				kp := peer.NewKeyPair()
-				if kp == nil {
-					logDebug.Println("Failed to derieve key-pair")
-				}
+
+				peer.TimerPacketReceived()
+				peer.TimerHandshakeComplete()
+				peer.NewKeyPair()
 				peer.SendKeepAlive()
-				peer.EventHandshakeComplete()
 
 			default:
 				logError.Println("Invalid message type in handshake queue")
@@ -464,12 +467,8 @@ func (peer *Peer) RoutineSequentialReceiver() {
 				return
 			}
 
-			// time (passive) keep-alive
-
-			peer.TimerStartKeepalive()
-
-			// refresh key material (rekey)
-
+			peer.TimerPacketReceived()
+			peer.TimerTransportReceived()
 			peer.KeepKeyFreshReceiving()
 
 			// check if using new key-pair
@@ -477,7 +476,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
 			kp := &peer.keyPairs
 			kp.mutex.Lock()
 			if kp.next == elem.keyPair {
-				peer.EventHandshakeComplete()
+				peer.TimerHandshakeComplete()
 				kp.previous = kp.current
 				kp.current = kp.next
 				kp.next = nil
@@ -490,6 +489,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
 				logDebug.Println("Received keep-alive from", peer.String())
 				return
 			}
+			peer.TimerDataReceived()
 
 			// verify source and strip padding
 
diff --git a/src/send.go b/src/send.go
index 7cdb806..37078b9 100644
--- a/src/send.go
+++ b/src/send.go
@@ -2,6 +2,7 @@ package main
 
 import (
 	"encoding/binary"
+	"errors"
 	"golang.org/x/crypto/chacha20poly1305"
 	"golang.org/x/net/ipv4"
 	"golang.org/x/net/ipv6"
@@ -51,6 +52,11 @@ func (peer *Peer) FlushNonceQueue() {
 	}
 }
 
+var (
+	ErrorNoEndpoint   = errors.New("No known endpoint for peer")
+	ErrorNoConnection = errors.New("No UDP socket for device")
+)
+
 func (device *Device) NewOutboundElement() *QueueOutboundElement {
 	return &QueueOutboundElement{
 		dropped: AtomicFalse,
@@ -103,6 +109,25 @@ func addToEncryptionQueue(
 	}
 }
 
+func (peer *Peer) SendBuffer(buffer []byte) (int, error) {
+
+	peer.mutex.RLock()
+	endpoint := peer.endpoint
+	peer.mutex.RUnlock()
+	if endpoint == nil {
+		return 0, ErrorNoEndpoint
+	}
+
+	peer.device.net.mutex.RLock()
+	conn := peer.device.net.conn
+	peer.device.net.mutex.RUnlock()
+	if conn == nil {
+		return 0, ErrorNoConnection
+	}
+
+	return conn.WriteToUDP(buffer, endpoint)
+}
+
 /* Reads packets from the TUN and inserts
  * into nonce queue for peer
  *
@@ -349,42 +374,27 @@ func (peer *Peer) RoutineSequentialSender() {
 
 		case elem := <-peer.queue.outbound:
 			elem.mutex.Lock()
+			if elem.IsDropped() {
+				continue
+			}
 
-			func() {
-				if elem.IsDropped() {
-					return
-				}
-
-				// get endpoint and connection
-
-				peer.mutex.RLock()
-				endpoint := peer.endpoint
-				peer.mutex.RUnlock()
-				if endpoint == nil {
-					logDebug.Println("No endpoint for", peer.String())
-					return
-				}
-
-				device.net.mutex.RLock()
-				conn := device.net.conn
-				device.net.mutex.RUnlock()
-				if conn == nil {
-					logDebug.Println("No source for device")
-					return
-				}
-
-				// send message and refresh keys
-
-				_, err := conn.WriteToUDP(elem.packet, endpoint)
-				if err != nil {
-					return
-				}
-
-				atomic.AddUint64(&peer.stats.txBytes, uint64(len(elem.packet)))
-				peer.TimerResetKeepalive()
-			}()
+			// send message and return buffer to pool
 
+			length := uint64(len(elem.packet))
+			_, err := peer.SendBuffer(elem.packet)
 			device.PutMessageBuffer(elem.buffer)
+			if err != nil {
+				continue
+			}
+			atomic.AddUint64(&peer.stats.txBytes, length)
+
+			// update timers
+
+			peer.TimerPacketSent()
+			if len(elem.packet) != MessageKeepaliveSize {
+				peer.TimerDataSent()
+			}
+			peer.KeepKeyFreshSending()
 		}
 	}
 }
diff --git a/src/timers.go b/src/timers.go
index 2454414..5a16e9b 100644
--- a/src/timers.go
+++ b/src/timers.go
@@ -44,21 +44,6 @@ func (peer *Peer) KeepKeyFreshReceiving() {
 	}
 }
 
-/* Called after succesfully completing a handshake.
- * i.e. after:
- * - Valid handshake response
- * - First transport message under the "next" key
- */
-func (peer *Peer) EventHandshakeComplete() {
-	peer.device.log.Info.Println("Negotiated new handshake for", peer.String())
-	peer.timer.zeroAllKeys.Reset(RejectAfterTime * 3)
-	atomic.StoreInt64(
-		&peer.stats.lastHandshakeNano,
-		time.Now().UnixNano(),
-	)
-	signalSend(peer.signal.handshakeCompleted)
-}
-
 /* Queues a keep-alive if no packets are queued for peer
  */
 func (peer *Peer) SendKeepAlive() bool {
@@ -75,69 +60,89 @@ func (peer *Peer) SendKeepAlive() bool {
 	return true
 }
 
-/* Starts the "keep-alive" timer
- * (if not already running),
- * in response to incomming messages
+/* Authenticated data packet send
+ * Always called together with peer.EventPacketSend
+ *
+ * - Start new handshake timer
  */
-func (peer *Peer) TimerStartKeepalive() {
+func (peer *Peer) TimerDataSent() {
+	timerStop(peer.timer.keepalivePassive)
+	if !peer.timer.pendingNewHandshake {
+		peer.timer.pendingNewHandshake = true
+		peer.timer.newHandshake.Reset(NewHandshakeTime)
+	}
+}
 
-	// check if acknowledgement timer set yet
-
-	var waiting int32 = AtomicTrue
-	waiting = atomic.SwapInt32(&peer.flags.keepaliveWaiting, waiting)
-	if waiting == AtomicTrue {
+/* Event:
+ * Received non-empty (authenticated) transport message
+ *
+ * - Start passive keep-alive timer
+ */
+func (peer *Peer) TimerDataReceived() {
+	if peer.timer.pendingKeepalivePassive {
+		peer.timer.needAnotherKeepalive = true
 		return
 	}
+	peer.timer.pendingKeepalivePassive = false
+	peer.timer.keepalivePassive.Reset(KeepaliveTimeout)
+}
 
-	// timer not yet set, start it
+/* Event:
+ * Any (authenticated) transport message received
+ * (keep-alive or data)
+ */
+func (peer *Peer) TimerTransportReceived() {
+	timerStop(peer.timer.newHandshake)
+}
 
-	wait := KeepaliveTimeout
+/* Event:
+ * Any packet send to the peer.
+ */
+func (peer *Peer) TimerPacketSent() {
 	interval := atomic.LoadUint64(&peer.persistentKeepaliveInterval)
 	if interval > 0 {
 		duration := time.Duration(interval) * time.Second
-		if duration < wait {
-			wait = duration
-		}
+		peer.timer.keepalivePersistent.Reset(duration)
 	}
 }
 
-/* Resets both keep-alive timers
+/* Event:
+ * Any authenticated packet received from peer
  */
-func (peer *Peer) TimerResetKeepalive() {
-
-	// reset persistent timer
-
-	interval := atomic.LoadUint64(&peer.persistentKeepaliveInterval)
-	if interval > 0 {
-		peer.timer.keepalivePersistent.Reset(
-			time.Duration(interval) * time.Second,
-		)
-	}
-
-	// stop acknowledgement timer
-
-	timerStop(peer.timer.keepalivePassive)
-	atomic.StoreInt32(&peer.flags.keepaliveWaiting, AtomicFalse)
+func (peer *Peer) TimerPacketReceived() {
+	peer.TimerPacketSent()
 }
 
-func (peer *Peer) BeginHandshakeInitiation() (*QueueOutboundElement, error) {
+/* Called after succesfully completing a handshake.
+ * i.e. after:
+ *
+ * - Valid handshake response
+ * - First transport message under the "next" key
+ */
+func (peer *Peer) TimerHandshakeComplete() {
+	timerStop(peer.timer.zeroAllKeys)
+	atomic.StoreInt64(
+		&peer.stats.lastHandshakeNano,
+		time.Now().UnixNano(),
+	)
+	signalSend(peer.signal.handshakeCompleted)
+	peer.device.log.Info.Println("Negotiated new handshake for", peer.String())
+}
 
-	// create initiation
-
-	elem := peer.device.NewOutboundElement()
-	msg, err := peer.device.CreateMessageInitiation(peer)
-	if err != nil {
-		return nil, err
+/* Called whenever an ephemeral key is generated
+ * i.e after:
+ *
+ * CreateMessageInitiation
+ * CreateMessageResponse
+ *
+ * Schedules the deletion of all key material
+ * upon failure to complete a handshake
+ */
+func (peer *Peer) TimerEphemeralKeyCreated() {
+	if !peer.timer.pendingZeroAllKeys {
+		peer.timer.pendingZeroAllKeys = true
+		peer.timer.zeroAllKeys.Reset(RejectAfterTime * 3)
 	}
-
-	// marshal & schedule for sending
-
-	writer := bytes.NewBuffer(elem.buffer[:0])
-	binary.Write(writer, binary.LittleEndian, msg)
-	elem.packet = writer.Bytes()
-	peer.mac.AddMacs(elem.packet)
-	addToOutboundQueue(peer.queue.outbound, elem)
-	return elem, err
 }
 
 func (peer *Peer) RoutineTimerHandler() {
@@ -157,17 +162,30 @@ func (peer *Peer) RoutineTimerHandler() {
 
 		case <-peer.timer.keepalivePersistent.C:
 
-			logDebug.Println("Sending persistent keep-alive to", peer.String())
-
-			peer.SendKeepAlive()
-			peer.TimerResetKeepalive()
+			interval := atomic.LoadUint64(&peer.persistentKeepaliveInterval)
+			if interval > 0 {
+				logDebug.Println("Sending persistent keep-alive to", peer.String())
+				peer.SendKeepAlive()
+			}
 
 		case <-peer.timer.keepalivePassive.C:
 
-			logDebug.Println("Sending passive persistent keep-alive to", peer.String())
+			logDebug.Println("Sending passive keep-alive to", peer.String())
 
 			peer.SendKeepAlive()
-			peer.TimerResetKeepalive()
+
+			if peer.timer.needAnotherKeepalive {
+				peer.timer.keepalivePassive.Reset(KeepaliveTimeout)
+				peer.timer.needAnotherKeepalive = true
+			}
+
+		// unresponsive session
+
+		case <-peer.timer.newHandshake.C:
+
+			logDebug.Println("Retrying handshake with", peer.String(), "due to lack of reply")
+
+			signalSend(peer.signal.handshakeBegin)
 
 		// clear key material
 
@@ -175,13 +193,15 @@ func (peer *Peer) RoutineTimerHandler() {
 
 			logDebug.Println("Clearing all key material for", peer.String())
 
-			kp := &peer.keyPairs
-			kp.mutex.Lock()
-
 			hs := &peer.handshake
 			hs.mutex.Lock()
 
-			// unmap local indecies
+			kp := &peer.keyPairs
+			kp.mutex.Lock()
+
+			peer.timer.pendingZeroAllKeys = false
+
+			// unmap indecies
 
 			indices.mutex.Lock()
 			if kp.previous != nil {
@@ -224,80 +244,103 @@ func (peer *Peer) RoutineTimerHandler() {
 func (peer *Peer) RoutineHandshakeInitiator() {
 	device := peer.device
 
-	var elem *QueueOutboundElement
-
 	logInfo := device.log.Info
 	logError := device.log.Error
 	logDebug := device.log.Debug
 	logDebug.Println("Routine, handshake initator, started for", peer.String())
 
+	var temp [256]byte
+
 	for {
 
 		// wait for signal
 
 		select {
 		case <-peer.signal.handshakeBegin:
+			signalSend(peer.signal.handshakeBegin)
 		case <-peer.signal.stop:
 			return
 		}
 
 		// wait for handshake
 
-		func() {
-			var err error
-			var deadline time.Time
-			for attempts := uint(1); ; attempts++ {
+		deadline := time.Now().Add(MaxHandshakeAttemptTime)
 
-				// clear completed signal
+	Loop:
+		for attempts := uint(1); ; attempts++ {
 
-				select {
-				case <-peer.signal.handshakeCompleted:
-				case <-peer.signal.stop:
-					return
-				default:
-				}
+			// clear completed signal
 
-				// create initiation
-
-				if elem != nil {
-					elem.Drop()
-				}
-				elem, err = peer.BeginHandshakeInitiation()
-				if err != nil {
-					logError.Println("Failed to create initiation message", err, "for", peer.String())
-					return
-				}
-
-				// set timeout
-
-				if attempts == 1 {
-					deadline = time.Now().Add(MaxHandshakeAttemptTime)
-				}
-				timeout := time.NewTimer(RekeyTimeout)
-				logDebug.Println("Handshake initiation attempt", attempts, "queued for", peer.String())
-
-				// wait for handshake or timeout
-
-				select {
-
-				case <-peer.signal.stop:
-					return
-
-				case <-peer.signal.handshakeCompleted:
-					<-timeout.C
-					return
-
-				case <-timeout.C:
-					if deadline.Before(time.Now().Add(RekeyTimeout)) {
-						logInfo.Println("Handshake negotiation timed out for", peer.String())
-						signalSend(peer.signal.flushNonceQueue)
-						timerStop(peer.timer.keepalivePersistent)
-						timerStop(peer.timer.keepalivePassive)
-						return
-					}
-				}
+			select {
+			case <-peer.signal.handshakeCompleted:
+			case <-peer.signal.stop:
+				return
+			default:
 			}
-		}()
+
+			// check if sufficient time for retry
+
+			if deadline.Before(time.Now().Add(RekeyTimeout)) {
+				logInfo.Println("Handshake negotiation timed out for", peer.String())
+				signalSend(peer.signal.flushNonceQueue)
+				timerStop(peer.timer.keepalivePersistent)
+				timerStop(peer.timer.keepalivePassive)
+				break Loop
+			}
+
+			// create initiation message
+
+			msg, err := peer.device.CreateMessageInitiation(peer)
+			if err != nil {
+				logError.Println("Failed to create handshake initiation message:", err)
+				break Loop
+			}
+			peer.TimerEphemeralKeyCreated()
+
+			// marshal and send
+
+			writer := bytes.NewBuffer(temp[:0])
+			binary.Write(writer, binary.LittleEndian, msg)
+			packet := writer.Bytes()
+			peer.mac.AddMacs(packet)
+			peer.TimerPacketSent()
+
+			_, err = peer.SendBuffer(packet)
+			if err != nil {
+				logError.Println(
+					"Failed to send handshake initiation message to",
+					peer.String(), ":", err,
+				)
+				continue
+			}
+
+			// set timeout
+
+			timeout := time.NewTimer(RekeyTimeout)
+			logDebug.Println(
+				"Handshake initiation attempt",
+				attempts, "sent to", peer.String(),
+			)
+
+			// wait for handshake or timeout
+
+			select {
+
+			case <-peer.signal.stop:
+				return
+
+			case <-peer.signal.handshakeCompleted:
+				<-timeout.C
+				break Loop
+
+			case <-timeout.C:
+				continue
+
+			}
+
+		}
+
+		// allow new signal to be set
 
 		signalClear(peer.signal.handshakeBegin)
 	}