From 93e3848ea76e755477bec8d9540a3c4c31ea7320 Mon Sep 17 00:00:00 2001
From: Mathias Hall-Andersen <mathias@hall-andersen.dk>
Date: Thu, 13 Jul 2017 14:32:40 +0200
Subject: [PATCH] Terminate on interface deletion

Program now terminates when the interface is removed
Increases the number of os threads (relevant for Go <1.5, not tested)
More consistent commenting
Improved logging (additional peer information)
---
 src/constants.go |  4 +--
 src/device.go    |  7 +++--
 src/ip.go        |  4 ---
 src/main.go      | 31 ++++++++++++++--------
 src/peer.go      | 19 ++++++++++---
 src/receive.go   | 24 +++++++++++------
 src/send.go      | 69 ++++++++++++++++++++++++++++--------------------
 src/timers.go    | 52 +++++++++++++++++-------------------
 src/trie.go      | 19 ++++++-------
 9 files changed, 132 insertions(+), 97 deletions(-)

diff --git a/src/constants.go b/src/constants.go
index 0384741..6b0d414 100644
--- a/src/constants.go
+++ b/src/constants.go
@@ -29,6 +29,6 @@ const (
 	QueueInboundSize       = 1024
 	QueueHandshakeSize     = 1024
 	QueueHandshakeBusySize = QueueHandshakeSize / 8
-	MinMessageSize         = MessageTransportSize // keep-alive
-	MaxMessageSize         = 4096                 // TODO: make depend on the MTU?
+	MinMessageSize         = MessageTransportSize // size of keep-alive
+	MaxMessageSize         = (1 << 16) - 1
 )
diff --git a/src/device.go b/src/device.go
index a26cc7b..b272544 100644
--- a/src/device.go
+++ b/src/device.go
@@ -98,9 +98,9 @@ func NewDevice(tun TUNDevice, logLevel int) *Device {
 	}
 
 	go device.RoutineBusyMonitor()
+	go device.RoutineWriteToTUN(tun)
 	go device.RoutineReadFromTUN(tun)
 	go device.RoutineReceiveIncomming()
-	go device.RoutineWriteToTUN(tun)
 	go device.ratelimiter.RoutineGarbageCollector(device.signal.stop)
 
 	return device
@@ -141,5 +141,8 @@ func (device *Device) RemoveAllPeers() {
 func (device *Device) Close() {
 	device.RemoveAllPeers()
 	close(device.signal.stop)
-	close(device.queue.encryption)
+}
+
+func (device *Device) Wait() {
+	<-device.signal.stop
 }
diff --git a/src/ip.go b/src/ip.go
index 36beb9c..752a404 100644
--- a/src/ip.go
+++ b/src/ip.go
@@ -5,17 +5,13 @@ import (
 )
 
 const (
-	IPv4version           = 4
 	IPv4offsetTotalLength = 2
 	IPv4offsetSrc         = 12
 	IPv4offsetDst         = IPv4offsetSrc + net.IPv4len
-	IPv4headerSize        = 20
 )
 
 const (
-	IPv6version             = 6
 	IPv6offsetPayloadLength = 4
 	IPv6offsetSrc           = 8
 	IPv6offsetDst           = IPv6offsetSrc + net.IPv6len
-	IPv6headerSize          = 40
 )
diff --git a/src/main.go b/src/main.go
index 50140e3..dc27472 100644
--- a/src/main.go
+++ b/src/main.go
@@ -5,6 +5,7 @@ import (
 	"log"
 	"net"
 	"os"
+	"runtime"
 )
 
 /* TODO: Fix logging
@@ -18,6 +19,10 @@ func main() {
 	}
 	deviceName := os.Args[1]
 
+	// increase number of go workers (for Go <1.5)
+
+	runtime.GOMAXPROCS(runtime.NumCPU())
+
 	// open TUN device
 
 	tun, err := CreateTUN(deviceName)
@@ -31,17 +36,21 @@ func main() {
 
 	// start configuration lister
 
-	socketPath := fmt.Sprintf("/var/run/wireguard/%s.sock", deviceName)
-	l, err := net.Listen("unix", socketPath)
-	if err != nil {
-		log.Fatal("listen error:", err)
-	}
-
-	for {
-		conn, err := l.Accept()
+	go func() {
+		socketPath := fmt.Sprintf("/var/run/wireguard/%s.sock", deviceName)
+		l, err := net.Listen("unix", socketPath)
 		if err != nil {
-			log.Fatal("accept error:", err)
+			log.Fatal("listen error:", err)
 		}
-		go ipcHandle(device, conn)
-	}
+
+		for {
+			conn, err := l.Accept()
+			if err != nil {
+				log.Fatal("accept error:", err)
+			}
+			go ipcHandle(device, conn)
+		}
+	}()
+
+	device.Wait()
 }
diff --git a/src/peer.go b/src/peer.go
index c8dc5c0..408c605 100644
--- a/src/peer.go
+++ b/src/peer.go
@@ -1,7 +1,9 @@
 package main
 
 import (
+	"encoding/base64"
 	"errors"
+	"fmt"
 	"net"
 	"sync"
 	"time"
@@ -38,9 +40,9 @@ type Peer struct {
 		/* Both keep-alive timers acts as one (see timers.go)
 		 * They are kept seperate to simplify the implementation.
 		 */
-		keepalivePersistent      *time.Timer // set for persistent keepalives
-		keepaliveAcknowledgement *time.Timer // set upon recieving messages
-		zeroAllKeys              *time.Timer // zero all key material after RejectAfterTime*3
+		keepalivePersistent *time.Timer // set for persistent keepalives
+		keepalivePassive    *time.Timer // set upon recieving messages
+		zeroAllKeys         *time.Timer // zero all key material after RejectAfterTime*3
 	}
 	queue struct {
 		nonce    chan *QueueOutboundElement // nonce / pre-handshake queue
@@ -63,8 +65,8 @@ func (device *Device) NewPeer(pk NoisePublicKey) *Peer {
 	peer.mac.Init(pk)
 	peer.device = device
 
+	peer.timer.keepalivePassive = NewStoppedTimer()
 	peer.timer.keepalivePersistent = NewStoppedTimer()
-	peer.timer.keepaliveAcknowledgement = NewStoppedTimer()
 	peer.timer.zeroAllKeys = NewStoppedTimer()
 
 	peer.flags.keepaliveWaiting = AtomicFalse
@@ -115,6 +117,15 @@ func (device *Device) NewPeer(pk NoisePublicKey) *Peer {
 	return peer
 }
 
+func (peer *Peer) String() string {
+	return fmt.Sprintf(
+		"peer(%d %s %s)",
+		peer.id,
+		peer.endpoint.String(),
+		base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]),
+	)
+}
+
 func (peer *Peer) Close() {
 	close(peer.signal.stop)
 }
diff --git a/src/receive.go b/src/receive.go
index 99089a9..3e649b6 100644
--- a/src/receive.go
+++ b/src/receive.go
@@ -4,6 +4,8 @@ import (
 	"bytes"
 	"encoding/binary"
 	"golang.org/x/crypto/chacha20poly1305"
+	"golang.org/x/net/ipv4"
+	"golang.org/x/net/ipv6"
 	"net"
 	"sync"
 	"sync/atomic"
@@ -362,7 +364,7 @@ func (device *Device) RoutineHandshake() {
 					return
 				}
 
-				logDebug.Println("Creating response...")
+				logDebug.Println("Creating response message for", peer.String())
 
 				outElem := device.NewOutboundElement()
 				writer := bytes.NewBuffer(outElem.data[:0])
@@ -416,6 +418,8 @@ func (peer *Peer) RoutineSequentialReceiver() {
 	var elem *QueueInboundElement
 
 	device := peer.device
+
+	logInfo := device.log.Info
 	logDebug := device.log.Debug
 	logDebug.Println("Routine, sequential receiver, started for peer", peer.id)
 
@@ -450,7 +454,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
 
 			peer.KeepKeyFreshReceiving()
 
-			// check if confirming handshake
+			// check if using new key-pair
 
 			kp := &peer.keyPairs
 			kp.mutex.Lock()
@@ -465,17 +469,18 @@ func (peer *Peer) RoutineSequentialReceiver() {
 			// check for keep-alive
 
 			if len(elem.packet) == 0 {
+				logDebug.Println("Received keep-alive from", peer.String())
 				return
 			}
 
 			// verify source and strip padding
 
 			switch elem.packet[0] >> 4 {
-			case IPv4version:
+			case ipv4.Version:
 
 				// strip padding
 
-				if len(elem.packet) < IPv4headerSize {
+				if len(elem.packet) < ipv4.HeaderLen {
 					return
 				}
 
@@ -487,31 +492,33 @@ func (peer *Peer) RoutineSequentialReceiver() {
 
 				dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
 				if device.routingTable.LookupIPv4(dst) != peer {
+					logInfo.Println("Packet with unallowed source IP from", peer.String())
 					return
 				}
 
-			case IPv6version:
+			case ipv6.Version:
 
 				// strip padding
 
-				if len(elem.packet) < IPv6headerSize {
+				if len(elem.packet) < ipv6.HeaderLen {
 					return
 				}
 
 				field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
 				length := binary.BigEndian.Uint16(field)
-				length += IPv6headerSize
+				length += ipv6.HeaderLen
 				elem.packet = elem.packet[:length]
 
 				// verify IPv6 source
 
 				dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
 				if device.routingTable.LookupIPv6(dst) != peer {
+					logInfo.Println("Packet with unallowed source IP from", peer.String())
 					return
 				}
 
 			default:
-				logDebug.Println("Receieved packet with unknown IP version")
+				logInfo.Println("Packet with invalid IP version from", peer.String())
 				return
 			}
 
@@ -522,6 +529,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
 }
 
 func (device *Device) RoutineWriteToTUN(tun TUNDevice) {
+
 	logError := device.log.Error
 	logDebug := device.log.Debug
 	logDebug.Println("Routine, sequential tun writer, started")
diff --git a/src/send.go b/src/send.go
index 5ea9a8f..d8ddc82 100644
--- a/src/send.go
+++ b/src/send.go
@@ -3,6 +3,8 @@ package main
 import (
 	"encoding/binary"
 	"golang.org/x/crypto/chacha20poly1305"
+	"golang.org/x/net/ipv4"
+	"golang.org/x/net/ipv6"
 	"net"
 	"sync"
 	"sync/atomic"
@@ -21,28 +23,26 @@ import (
  * The functions in this file occure (roughly) in the order packets are processed.
  */
 
-/* A work unit
- *
- * The sequential consumers will attempt to take the lock,
- * workers release lock when they have completed work on the packet.
+/* The sequential consumers will attempt to take the lock,
+ * workers release lock when they have completed work (encryption) on the packet.
  *
  * If the element is inserted into the "encryption queue",
- * the content is preceeded by enough "junk" to contain the header
+ * the content is preceeded by enough "junk" to contain the transport header
  * (to allow the construction of transport messages in-place)
  */
 type QueueOutboundElement struct {
 	dropped int32
 	mutex   sync.Mutex
-	data    [MaxMessageSize]byte
-	packet  []byte   // slice of "data" (always!)
-	nonce   uint64   // nonce for encryption
-	keyPair *KeyPair // key-pair for encryption
-	peer    *Peer    // related peer
+	data    [MaxMessageSize]byte // slice holding the packet data
+	packet  []byte               // slice of "data" (always!)
+	nonce   uint64               // nonce for encryption
+	keyPair *KeyPair             // key-pair for encryption
+	peer    *Peer                // related peer
 }
 
 func (peer *Peer) FlushNonceQueue() {
 	elems := len(peer.queue.nonce)
-	for i := 0; i < elems; i += 1 {
+	for i := 0; i < elems; i++ {
 		select {
 		case <-peer.queue.nonce:
 		default:
@@ -111,14 +111,18 @@ func addToEncryptionQueue(
  * Obs. Single instance per TUN device
  */
 func (device *Device) RoutineReadFromTUN(tun TUNDevice) {
+
 	if tun == nil {
-		// dummy
 		return
 	}
 
 	elem := device.NewOutboundElement()
 
-	device.log.Debug.Println("Routine, TUN Reader: started")
+	logDebug := device.log.Debug
+	logError := device.log.Error
+
+	logDebug.Println("Routine, TUN Reader: started")
+
 	for {
 		// read packet
 
@@ -129,12 +133,17 @@ func (device *Device) RoutineReadFromTUN(tun TUNDevice) {
 		elem.packet = elem.data[MessageTransportHeaderSize:]
 		size, err := tun.Read(elem.packet)
 		if err != nil {
-			device.log.Error.Println("Failed to read packet from TUN device:", err)
-			continue
+
+			// stop process
+
+			logError.Println("Failed to read packet from TUN device:", err)
+			device.Close()
+			return
 		}
+
 		elem.packet = elem.packet[:size]
-		if len(elem.packet) < IPv4headerSize {
-			device.log.Error.Println("Packet too short, length:", size)
+		if len(elem.packet) < ipv4.HeaderLen {
+			logError.Println("Packet too short, length:", size)
 			continue
 		}
 
@@ -142,23 +151,24 @@ func (device *Device) RoutineReadFromTUN(tun TUNDevice) {
 
 		var peer *Peer
 		switch elem.packet[0] >> 4 {
-		case IPv4version:
+		case ipv4.Version:
 			dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
 			peer = device.routingTable.LookupIPv4(dst)
 
-		case IPv6version:
+		case ipv6.Version:
 			dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
 			peer = device.routingTable.LookupIPv6(dst)
 
 		default:
-			device.log.Debug.Println("Receieved packet with unknown IP version")
+			logDebug.Println("Receieved packet with unknown IP version")
 		}
 
 		if peer == nil {
 			continue
 		}
+
 		if peer.endpoint == nil {
-			device.log.Debug.Println("No known endpoint for peer", peer.id)
+			logDebug.Println("No known endpoint for peer", peer.String())
 			continue
 		}
 
@@ -184,7 +194,7 @@ func (peer *Peer) RoutineNonce() {
 
 	device := peer.device
 	logDebug := device.log.Debug
-	logDebug.Println("Routine, nonce worker, started for peer", peer.id)
+	logDebug.Println("Routine, nonce worker, started for peer", peer.String())
 
 	func() {
 
@@ -216,15 +226,15 @@ func (peer *Peer) RoutineNonce() {
 					}
 				}
 				signalSend(peer.signal.handshakeBegin)
-				logDebug.Println("Waiting for key-pair, peer", peer.id)
+				logDebug.Println("Awaiting key-pair for", peer.String())
 
 				select {
 				case <-peer.signal.newKeyPair:
-					logDebug.Println("Key-pair negotiated for peer", peer.id)
+					logDebug.Println("Key-pair negotiated for", peer.String())
 					goto NextPacket
 
 				case <-peer.signal.flushNonceQueue:
-					logDebug.Println("Clearing queue for peer", peer.id)
+					logDebug.Println("Clearing queue for", peer.String())
 					peer.FlushNonceQueue()
 					elem = nil
 					goto NextPacket
@@ -313,13 +323,14 @@ func (peer *Peer) RoutineSequentialSender() {
 	device := peer.device
 
 	logDebug := device.log.Debug
-	logDebug.Println("Routine, sequential sender, started for peer", peer.id)
+	logDebug.Println("Routine, sequential sender, started for", peer.String())
 
 	for {
 		select {
 		case <-peer.signal.stop:
-			logDebug.Println("Routine, sequential sender, stopped for peer", peer.id)
+			logDebug.Println("Routine, sequential sender, stopped for", peer.String())
 			return
+
 		case work := <-peer.queue.outbound:
 			work.mutex.Lock()
 			if work.IsDropped() {
@@ -334,7 +345,7 @@ func (peer *Peer) RoutineSequentialSender() {
 				defer peer.mutex.RUnlock()
 
 				if peer.endpoint == nil {
-					logDebug.Println("No endpoint for peer:", peer.id)
+					logDebug.Println("No endpoint for", peer.String())
 					return
 				}
 
@@ -352,7 +363,7 @@ func (peer *Peer) RoutineSequentialSender() {
 				}
 				atomic.AddUint64(&peer.txBytes, uint64(len(work.packet)))
 
-				// reset keep-alive (passive keep-alives / acknowledgements)
+				// reset keep-alive
 
 				peer.TimerResetKeepalive()
 			}()
diff --git a/src/timers.go b/src/timers.go
index 6393955..2e5046e 100644
--- a/src/timers.go
+++ b/src/timers.go
@@ -50,7 +50,7 @@ func (peer *Peer) KeepKeyFreshReceiving() {
  * - First transport message under the "next" key
  */
 func (peer *Peer) EventHandshakeComplete() {
-	peer.device.log.Debug.Println("Handshake completed")
+	peer.device.log.Info.Println("Negotiated new handshake for", peer.String())
 	peer.timer.zeroAllKeys.Reset(RejectAfterTime * 3)
 	signalSend(peer.signal.handshakeCompleted)
 }
@@ -112,7 +112,7 @@ func (peer *Peer) TimerResetKeepalive() {
 
 	// stop acknowledgement timer
 
-	timerStop(peer.timer.keepaliveAcknowledgement)
+	timerStop(peer.timer.keepalivePassive)
 	atomic.StoreInt32(&peer.flags.keepaliveWaiting, AtomicFalse)
 }
 
@@ -140,7 +140,7 @@ func (peer *Peer) RoutineTimerHandler() {
 	device := peer.device
 
 	logDebug := device.log.Debug
-	logDebug.Println("Routine, timer handler, started for peer", peer.id)
+	logDebug.Println("Routine, timer handler, started for peer", peer.String())
 
 	for {
 		select {
@@ -152,14 +152,14 @@ func (peer *Peer) RoutineTimerHandler() {
 
 		case <-peer.timer.keepalivePersistent.C:
 
-			logDebug.Println("Sending persistent keep-alive to peer", peer.id)
+			logDebug.Println("Sending persistent keep-alive to", peer.String())
 
 			peer.SendKeepAlive()
 			peer.TimerResetKeepalive()
 
-		case <-peer.timer.keepaliveAcknowledgement.C:
+		case <-peer.timer.keepalivePassive.C:
 
-			logDebug.Println("Sending passive persistent keep-alive to peer", peer.id)
+			logDebug.Println("Sending passive persistent keep-alive to", peer.String())
 
 			peer.SendKeepAlive()
 			peer.TimerResetKeepalive()
@@ -168,7 +168,7 @@ func (peer *Peer) RoutineTimerHandler() {
 
 		case <-peer.timer.zeroAllKeys.C:
 
-			logDebug.Println("Clearing all key material for peer", peer.id)
+			logDebug.Println("Clearing all key material for", peer.String())
 
 			// zero out key pairs
 
@@ -208,14 +208,12 @@ func (peer *Peer) RoutineHandshakeInitiator() {
 
 	var elem *QueueOutboundElement
 
+	logInfo := device.log.Info
 	logError := device.log.Error
 	logDebug := device.log.Debug
-	logDebug.Println("Routine, handshake initator, started for peer", peer.id)
+	logDebug.Println("Routine, handshake initator, started for", peer.String())
 
-	for run := true; run; {
-		var err error
-		var attempts uint
-		var deadline time.Time
+	for {
 
 		// wait for signal
 
@@ -227,15 +225,17 @@ func (peer *Peer) RoutineHandshakeInitiator() {
 
 		// wait for handshake
 
-		run = func() bool {
-			for {
+		func() {
+			var err error
+			var deadline time.Time
+			for attempts := uint(1); ; attempts++ {
 
 				// clear completed signal
 
 				select {
 				case <-peer.signal.handshakeCompleted:
 				case <-peer.signal.stop:
-					return false
+					return
 				default:
 				}
 
@@ -246,43 +246,39 @@ func (peer *Peer) RoutineHandshakeInitiator() {
 				}
 				elem, err = peer.BeginHandshakeInitiation()
 				if err != nil {
-					logError.Println("Failed to create initiation message:", err)
-					break
+					logError.Println("Failed to create initiation message", err, "for", peer.String())
+					return
 				}
 
 				// set timeout
 
-				attempts += 1
 				if attempts == 1 {
 					deadline = time.Now().Add(MaxHandshakeAttemptTime)
 				}
 				timeout := time.NewTimer(RekeyTimeout)
-				logDebug.Println("Handshake initiation attempt", attempts, "queued for peer", peer.id)
+				logDebug.Println("Handshake initiation attempt", attempts, "queued for", peer.String())
 
 				// wait for handshake or timeout
 
 				select {
+
 				case <-peer.signal.stop:
-					return true
+					return
 
 				case <-peer.signal.handshakeCompleted:
 					<-timeout.C
-					return true
+					return
 
 				case <-timeout.C:
-					logDebug.Println("Timeout")
-
-					// check if sufficient time for retry
-
 					if deadline.Before(time.Now().Add(RekeyTimeout)) {
+						logInfo.Println("Handshake negotiation timed out for", peer.String())
 						signalSend(peer.signal.flushNonceQueue)
 						timerStop(peer.timer.keepalivePersistent)
-						timerStop(peer.timer.keepaliveAcknowledgement)
-						return true
+						timerStop(peer.timer.keepalivePassive)
+						return
 					}
 				}
 			}
-			return true
 		}()
 
 		signalClear(peer.signal.handshakeBegin)
diff --git a/src/trie.go b/src/trie.go
index c2304b2..e81b5b6 100644
--- a/src/trie.go
+++ b/src/trie.go
@@ -23,7 +23,8 @@ type Trie struct {
 	bits  []byte
 	peer  *Peer
 
-	// Index of "branching" bit
+	// index of "branching" bit
+
 	bit_at_byte  uint
 	bit_at_shift uint
 }
@@ -36,7 +37,7 @@ type Trie struct {
 func commonBits(ip1 net.IP, ip2 net.IP) uint {
 	var i uint
 	size := uint(len(ip1))
-	for i = 0; i < size; i += 1 {
+	for i = 0; i < size; i++ {
 		v := ip1[i] ^ ip2[i]
 		if v != 0 {
 			v >>= 1
@@ -84,7 +85,7 @@ func (node *Trie) RemovePeer(p *Peer) *Trie {
 		return node
 	}
 
-	// Walk recursivly
+	// walk recursivly
 
 	node.child[0] = node.child[0].RemovePeer(p)
 	node.child[1] = node.child[1].RemovePeer(p)
@@ -93,7 +94,7 @@ func (node *Trie) RemovePeer(p *Peer) *Trie {
 		return node
 	}
 
-	// Remove peer & merge
+	// remove peer & merge
 
 	node.peer = nil
 	if node.child[0] == nil {
@@ -108,7 +109,7 @@ func (node *Trie) choose(ip net.IP) byte {
 
 func (node *Trie) Insert(ip net.IP, cidr uint, peer *Peer) *Trie {
 
-	// At leaf
+	// at leaf
 
 	if node == nil {
 		return &Trie{
@@ -120,7 +121,7 @@ func (node *Trie) Insert(ip net.IP, cidr uint, peer *Peer) *Trie {
 		}
 	}
 
-	// Traverse deeper
+	// traverse deeper
 
 	common := commonBits(node.bits, ip)
 	if node.cidr <= cidr && common >= node.cidr {
@@ -133,7 +134,7 @@ func (node *Trie) Insert(ip net.IP, cidr uint, peer *Peer) *Trie {
 		return node
 	}
 
-	// Split node
+	// split node
 
 	newNode := &Trie{
 		bits:         ip,
@@ -145,7 +146,7 @@ func (node *Trie) Insert(ip net.IP, cidr uint, peer *Peer) *Trie {
 
 	cidr = min(cidr, common)
 
-	// Check for shorter prefix
+	// check for shorter prefix
 
 	if newNode.cidr == cidr {
 		bit := newNode.choose(node.bits)
@@ -153,7 +154,7 @@ func (node *Trie) Insert(ip net.IP, cidr uint, peer *Peer) *Trie {
 		return newNode
 	}
 
-	// Create new parent for node & newNode
+	// create new parent for node & newNode
 
 	parent := &Trie{
 		bits:         ip,