From eb75ff430d1f78e129bbfe49d612f241ca418df4 Mon Sep 17 00:00:00 2001
From: Mathias Hall-Andersen <mathias@hall-andersen.dk>
Date: Mon, 26 Jun 2017 22:07:29 +0200
Subject: [PATCH] Begin implementation of outbound work queue

---
 src/device.go         |   3 +
 src/index.go          |  75 +++++++++++++++----------
 src/keypair.go        |  16 +++++-
 src/noise_protocol.go |  43 +++++++++++++--
 src/peer.go           |   4 +-
 src/send.go           | 124 ++++++++++++++++++++++++++----------------
 6 files changed, 181 insertions(+), 84 deletions(-)

diff --git a/src/device.go b/src/device.go
index ce10a63..4b8cda0 100644
--- a/src/device.go
+++ b/src/device.go
@@ -2,11 +2,14 @@ package main
 
 import (
 	"log"
+	"net"
 	"sync"
 )
 
 type Device struct {
 	mtu               int
+	source            *net.UDPAddr // UDP source address
+	conn              *net.UDPConn // UDP "connection"
 	mutex             sync.RWMutex
 	peers             map[NoisePublicKey]*Peer
 	indices           IndexTable
diff --git a/src/index.go b/src/index.go
index 81f71e9..9178510 100644
--- a/src/index.go
+++ b/src/index.go
@@ -11,10 +11,15 @@ import (
  *
  */
 
+type IndexTableEntry struct {
+	peer      *Peer
+	handshake *Handshake
+	keyPair   *KeyPair
+}
+
 type IndexTable struct {
-	mutex      sync.RWMutex
-	keypairs   map[uint32]*KeyPair
-	handshakes map[uint32]*Peer
+	mutex sync.RWMutex
+	table map[uint32]IndexTableEntry
 }
 
 func randUint32() (uint32, error) {
@@ -32,52 +37,66 @@ func randUint32() (uint32, error) {
 
 func (table *IndexTable) Init() {
 	table.mutex.Lock()
-	defer table.mutex.Unlock()
-	table.keypairs = make(map[uint32]*KeyPair)
-	table.handshakes = make(map[uint32]*Peer)
+	table.table = make(map[uint32]IndexTableEntry)
+	table.mutex.Unlock()
+}
+
+func (table *IndexTable) ClearIndex(index uint32) {
+	if index == 0 {
+		return
+	}
+	table.mutex.Lock()
+	delete(table.table, index)
+	table.mutex.Unlock()
+}
+
+func (table *IndexTable) Insert(key uint32, value IndexTableEntry) {
+	table.mutex.Lock()
+	table.table[key] = value
+	table.mutex.Unlock()
 }
 
 func (table *IndexTable) NewIndex(peer *Peer) (uint32, error) {
-	table.mutex.Lock()
-	defer table.mutex.Unlock()
 	for {
 		// generate random index
 
-		id, err := randUint32()
+		index, err := randUint32()
 		if err != nil {
-			return id, err
+			return index, err
 		}
-		if id == 0 {
+		if index == 0 {
 			continue
 		}
 
 		// check if index used
 
-		_, ok := table.keypairs[id]
+		table.mutex.RLock()
+		_, ok := table.table[index]
 		if ok {
 			continue
 		}
-		_, ok = table.handshakes[id]
-		if ok {
+		table.mutex.RUnlock()
+
+		// replace index
+
+		table.mutex.Lock()
+		_, found := table.table[index]
+		if found {
+			table.mutex.Unlock()
 			continue
 		}
-
-		// clean old index
-
-		delete(table.handshakes, peer.handshake.localIndex)
-		table.handshakes[id] = peer
-		return id, nil
+		table.table[index] = IndexTableEntry{
+			peer:      peer,
+			handshake: &peer.handshake,
+			keyPair:   nil,
+		}
+		table.mutex.Unlock()
+		return index, nil
 	}
 }
 
-func (table *IndexTable) LookupKeyPair(id uint32) *KeyPair {
+func (table *IndexTable) Lookup(id uint32) IndexTableEntry {
 	table.mutex.RLock()
 	defer table.mutex.RUnlock()
-	return table.keypairs[id]
-}
-
-func (table *IndexTable) LookupHandshake(id uint32) *Peer {
-	table.mutex.RLock()
-	defer table.mutex.RUnlock()
-	return table.handshakes[id]
+	return table.table[id]
 }
diff --git a/src/keypair.go b/src/keypair.go
index e7961a8..53e123f 100644
--- a/src/keypair.go
+++ b/src/keypair.go
@@ -16,6 +16,18 @@ type KeyPairs struct {
 	mutex      sync.RWMutex
 	current    *KeyPair
 	previous   *KeyPair
-	next       *KeyPair
-	newKeyPair chan bool
+	next       *KeyPair  // not yet "confirmed by transport"
+	newKeyPair chan bool // signals when "current" has been updated
+}
+
+func (kp *KeyPairs) Init() {
+	kp.mutex.Lock()
+	kp.newKeyPair = make(chan bool, 5)
+	kp.mutex.Unlock()
+}
+
+func (kp *KeyPairs) Current() *KeyPair {
+	kp.mutex.RLock()
+	defer kp.mutex.RUnlock()
+	return kp.current
 }
diff --git a/src/noise_protocol.go b/src/noise_protocol.go
index a16908a..bf1db9b 100644
--- a/src/noise_protocol.go
+++ b/src/noise_protocol.go
@@ -120,13 +120,15 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
 		return nil, err
 	}
 
+	device.indices.ClearIndex(handshake.localIndex)
+	handshake.localIndex, err = device.indices.NewIndex(peer)
+
 	// assign index
 
 	var msg MessageInitiation
 
 	msg.Type = MessageInitiationType
 	msg.Ephemeral = handshake.localEphemeral.publicKey()
-	handshake.localIndex, err = device.indices.NewIndex(peer)
 
 	if err != nil {
 		return nil, err
@@ -249,6 +251,7 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
 	// assign index
 
 	var err error
+	device.indices.ClearIndex(handshake.localIndex)
 	handshake.localIndex, err = device.indices.NewIndex(peer)
 	if err != nil {
 		return nil, err
@@ -299,11 +302,12 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
 
 	// lookup handshake by reciever
 
-	peer := device.indices.LookupHandshake(msg.Reciever)
-	if peer == nil {
+	lookup := device.indices.Lookup(msg.Reciever)
+	handshake := lookup.handshake
+	if handshake == nil {
 		return nil
 	}
-	handshake := &peer.handshake
+
 	handshake.mutex.Lock()
 	defer handshake.mutex.Unlock()
 	if handshake.state != HandshakeInitiationCreated {
@@ -345,7 +349,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
 	handshake.remoteIndex = msg.Sender
 	handshake.state = HandshakeResponseConsumed
 
-	return peer
+	return lookup.peer
 }
 
 func (peer *Peer) NewKeyPair() *KeyPair {
@@ -355,13 +359,16 @@ func (peer *Peer) NewKeyPair() *KeyPair {
 
 	// derive keys
 
+	var isInitiator bool
 	var sendKey [chacha20poly1305.KeySize]byte
 	var recvKey [chacha20poly1305.KeySize]byte
 
 	if handshake.state == HandshakeResponseConsumed {
 		sendKey, recvKey = KDF2(handshake.chainKey[:], nil)
+		isInitiator = true
 	} else if handshake.state == HandshakeResponseCreated {
 		recvKey, sendKey = KDF2(handshake.chainKey[:], nil)
+		isInitiator = false
 	} else {
 		return nil
 	}
@@ -369,16 +376,40 @@ func (peer *Peer) NewKeyPair() *KeyPair {
 	// create AEAD instances
 
 	var keyPair KeyPair
+
 	keyPair.send, _ = chacha20poly1305.New(sendKey[:])
 	keyPair.recv, _ = chacha20poly1305.New(recvKey[:])
 	keyPair.sendNonce = 0
 	keyPair.recvNonce = 0
 
+	// remap index
+
+	peer.device.indices.Insert(handshake.localIndex, IndexTableEntry{
+		peer:      peer,
+		keyPair:   &keyPair,
+		handshake: nil,
+	})
+	handshake.localIndex = 0
+
+	// rotate key pairs
+
+	func() {
+		kp := &peer.keyPairs
+		kp.mutex.Lock()
+		defer kp.mutex.Unlock()
+		if isInitiator {
+			kp.previous = peer.keyPairs.current
+			kp.current = &keyPair
+			kp.newKeyPair <- true
+		} else {
+			kp.next = &keyPair
+		}
+	}()
+
 	// zero handshake
 
 	handshake.chainKey = [blake2s.Size]byte{}
 	handshake.localEphemeral = NoisePrivateKey{}
 	peer.handshake.state = HandshakeZeroed
-
 	return &keyPair
 }
diff --git a/src/peer.go b/src/peer.go
index 42b9e8d..6a879cb 100644
--- a/src/peer.go
+++ b/src/peer.go
@@ -14,8 +14,7 @@ const (
 
 type Peer struct {
 	mutex                       sync.RWMutex
-	endpointIP                  net.IP        //
-	endpointPort                uint16        //
+	endpoint                    *net.UDPAddr
 	persistentKeepaliveInterval time.Duration // 0 = disabled
 	keyPairs                    KeyPairs
 	handshake                   Handshake
@@ -35,6 +34,7 @@ func (device *Device) NewPeer(pk NoisePublicKey) *Peer {
 
 	peer.mutex.Lock()
 	peer.device = device
+	peer.keyPairs.Init()
 	peer.queueOutbound = make(chan *OutboundWorkQueueElement, OutboundQueueSize)
 
 	// map public key
diff --git a/src/send.go b/src/send.go
index 9790320..da5905d 100644
--- a/src/send.go
+++ b/src/send.go
@@ -1,9 +1,11 @@
 package main
 
 import (
+	"encoding/binary"
+	"golang.org/x/crypto/chacha20poly1305"
 	"net"
 	"sync"
-	"sync/atomic"
+	"time"
 )
 
 /* Handles outbound flow
@@ -70,85 +72,115 @@ func (device *Device) SendPacket(packet []byte) {
  *
  * TODO: avoid dynamic allocation of work queue elements
  */
-func (peer *Peer) ConsumeOutboundPackets() {
+func (peer *Peer) RoutineOutboundNonceWorker() {
+	var packet []byte
+	var keyPair *KeyPair
+	var flushTimer time.Timer
+
 	for {
-		// wait for key pair
-		keyPair := func() *KeyPair {
-			peer.keyPairs.mutex.RLock()
-			defer peer.keyPairs.mutex.RUnlock()
-			return peer.keyPairs.current
-		}()
-		if keyPair == nil {
-			if len(peer.queueOutboundRouting) > 0 {
-				// TODO: start handshake
-				<-peer.keyPairs.newKeyPair
-			}
-			continue
+
+		// wait for packet
+
+		if packet == nil {
+			packet = <-peer.queueOutboundRouting
 		}
 
-		// assign packets key pair
-		for {
+		// wait for key pair
+
+		for keyPair == nil {
+			flushTimer.Reset(time.Second * 10)
+			// TODO: Handshake or NOP
 			select {
 			case <-peer.keyPairs.newKeyPair:
-			default:
-			case <-peer.keyPairs.newKeyPair:
-			case packet := <-peer.queueOutboundRouting:
+				keyPair = peer.keyPairs.Current()
+				continue
+			case <-flushTimer.C:
+				size := len(peer.queueOutboundRouting)
+				for i := 0; i < size; i += 1 {
+					<-peer.queueOutboundRouting
+				}
+				packet = nil
+			}
+			break
+		}
 
-				// create new work element
+		// process current packet
 
-				work := new(OutboundWorkQueueElement)
-				work.wg.Add(1)
-				work.keyPair = keyPair
-				work.packet = packet
-				work.nonce = atomic.AddUint64(&keyPair.sendNonce, 1) - 1
+		if packet != nil {
 
-				peer.queueOutbound <- work
+			// create work element
 
-				// drop packets until there is room
+			work := new(OutboundWorkQueueElement)
+			work.wg.Add(1)
+			work.keyPair = keyPair
+			work.packet = packet
+			work.nonce = keyPair.sendNonce
 
+			packet = nil
+			peer.queueOutbound <- work
+			keyPair.sendNonce += 1
+
+			// drop packets until there is space
+
+			func() {
 				for {
 					select {
 					case peer.device.queueWorkOutbound <- work:
-						break
+						return
 					default:
 						drop := <-peer.device.queueWorkOutbound
 						drop.packet = nil
 						drop.wg.Done()
 					}
 				}
-			}
+			}()
 		}
 	}
 }
 
+/* Go routine
+ *
+ * sequentially reads packets from queue and sends to endpoint
+ *
+ */
 func (peer *Peer) RoutineSequential() {
 	for work := range peer.queueOutbound {
 		work.wg.Wait()
+
+		// check if dropped ("ghost packet")
+
 		if work.packet == nil {
 			continue
 		}
+
+		//
+
 	}
 }
 
-func (device *Device) EncryptionWorker() {
-	for {
-		work := <-device.queueWorkOutbound
+func (device *Device) RoutineEncryptionWorker() {
+	var nonce [chacha20poly1305.NonceSize]byte
+	for work := range device.queueWorkOutbound {
+		// pad packet
 
-		func() {
-			defer work.wg.Done()
+		padding := device.mtu - len(work.packet)
+		if padding < 0 {
+			work.packet = nil
+			work.wg.Done()
+		}
+		for n := 0; n < padding; n += 1 {
+			work.packet = append(work.packet, 0)
+		}
 
-			// pad packet
-			padding := device.mtu - len(work.packet)
-			if padding < 0 {
-				work.packet = nil
-				return
-			}
-			for n := 0; n < padding; n += 1 {
-				work.packet = append(work.packet, 0) // TODO: gotta be a faster way
-			}
+		// encrypt
 
-			//
-
-		}()
+		binary.LittleEndian.PutUint64(nonce[4:], work.nonce)
+		work.packet = work.keyPair.send.Seal(
+			work.packet[:0],
+			nonce[:],
+			work.packet,
+			nil,
+		)
+		work.wg.Done()
 	}
 }