From cf3a5130d3aa53fc56c7c3194ee326d5a1d21970 Mon Sep 17 00:00:00 2001
From: Mathias Hall-Andersen <mathias@hall-andersen.dk>
Date: Sat, 24 Jun 2017 22:03:52 +0200
Subject: [PATCH] Completed noise handshake

---
 src/index.go          |  17 +++---
 src/keypair.go        |   8 +--
 src/noise_helpers.go  |  17 +-----
 src/noise_protocol.go | 125 +++++++++++++++++++++++++++++++++++++-----
 src/noise_test.go     |  68 ++++++++++++++++++++++-
 5 files changed, 191 insertions(+), 44 deletions(-)

diff --git a/src/index.go b/src/index.go
index 83a7e29..81f71e9 100644
--- a/src/index.go
+++ b/src/index.go
@@ -6,13 +6,15 @@ import (
 )
 
 /* Index=0 is reserved for unset indecies
+ *
+ * TODO: Rethink map[id] -> peer VS map[id] -> handshake and handshake <ref> peer
  *
  */
 
 type IndexTable struct {
 	mutex      sync.RWMutex
 	keypairs   map[uint32]*KeyPair
-	handshakes map[uint32]*Handshake
+	handshakes map[uint32]*Peer
 }
 
 func randUint32() (uint32, error) {
@@ -32,10 +34,10 @@ func (table *IndexTable) Init() {
 	table.mutex.Lock()
 	defer table.mutex.Unlock()
 	table.keypairs = make(map[uint32]*KeyPair)
-	table.handshakes = make(map[uint32]*Handshake)
+	table.handshakes = make(map[uint32]*Peer)
 }
 
-func (table *IndexTable) NewIndex(handshake *Handshake) (uint32, error) {
+func (table *IndexTable) NewIndex(peer *Peer) (uint32, error) {
 	table.mutex.Lock()
 	defer table.mutex.Unlock()
 	for {
@@ -60,11 +62,10 @@ func (table *IndexTable) NewIndex(handshake *Handshake) (uint32, error) {
 			continue
 		}
 
-		// update the index
+		// clean old index
 
-		delete(table.handshakes, handshake.localIndex)
-		handshake.localIndex = id
-		table.handshakes[id] = handshake
+		delete(table.handshakes, peer.handshake.localIndex)
+		table.handshakes[id] = peer
 		return id, nil
 	}
 }
@@ -75,7 +76,7 @@ func (table *IndexTable) LookupKeyPair(id uint32) *KeyPair {
 	return table.keypairs[id]
 }
 
-func (table *IndexTable) LookupHandshake(id uint32) *Handshake {
+func (table *IndexTable) LookupHandshake(id uint32) *Peer {
 	table.mutex.RLock()
 	defer table.mutex.RUnlock()
 	return table.handshakes[id]
diff --git a/src/keypair.go b/src/keypair.go
index 22a8244..e434c74 100644
--- a/src/keypair.go
+++ b/src/keypair.go
@@ -5,8 +5,8 @@ import (
 )
 
 type KeyPair struct {
-	recieveKey   cipher.AEAD
-	recieveNonce NoiseNonce
-	sendKey      cipher.AEAD
-	sendNonce    NoiseNonce
+	recv      cipher.AEAD
+	recvNonce NoiseNonce
+	send      cipher.AEAD
+	sendNonce NoiseNonce
 }
diff --git a/src/noise_helpers.go b/src/noise_helpers.go
index eadbc07..e163ace 100644
--- a/src/noise_helpers.go
+++ b/src/noise_helpers.go
@@ -45,22 +45,7 @@ func KDF3(key []byte, input []byte) (t0 [blake2s.Size]byte, t1 [blake2s.Size]byt
 	return
 }
 
-/*
- *
- */
-
-func addToChainKey(c [blake2s.Size]byte, data []byte) [blake2s.Size]byte {
-	return KDF1(c[:], data)
-}
-
-func addToHash(h [blake2s.Size]byte, data []byte) [blake2s.Size]byte {
-	return blake2s.Sum256(append(h[:], data...))
-}
-
-/* Curve25519 wrappers
- *
- * TODO: Rethink this
- */
+/* curve25519 wrappers */
 
 func newPrivateKey() (sk NoisePrivateKey, err error) {
 	// clamping: https://cr.yp.to/ecdh.html
diff --git a/src/noise_protocol.go b/src/noise_protocol.go
index b9c8981..7f26cf1 100644
--- a/src/noise_protocol.go
+++ b/src/noise_protocol.go
@@ -9,9 +9,11 @@ import (
 )
 
 const (
-	HandshakeInitialCreated = iota
+	HandshakeReset = iota
+	HandshakeInitialCreated
 	HandshakeInitialConsumed
 	HandshakeResponseCreated
+	HandshakeResponseConsumed
 )
 
 const (
@@ -71,7 +73,6 @@ type Handshake struct {
 }
 
 var (
-	EmptyMessage   []byte
 	ZeroNonce      [chacha20poly1305.NonceSize]byte
 	InitalChainKey [blake2s.Size]byte
 	InitalHash     [blake2s.Size]byte
@@ -82,6 +83,14 @@ func init() {
 	InitalHash = blake2s.Sum256(append(InitalChainKey[:], []byte(WGIdentifier)...))
 }
 
+func addToChainKey(c [blake2s.Size]byte, data []byte) [blake2s.Size]byte {
+	return KDF1(c[:], data)
+}
+
+func addToHash(h [blake2s.Size]byte, data []byte) [blake2s.Size]byte {
+	return blake2s.Sum256(append(h[:], data...))
+}
+
 func (h *Handshake) addToHash(data []byte) {
 	h.hash = addToHash(h.hash, data)
 }
@@ -90,11 +99,6 @@ func (h *Handshake) addToChainKey(data []byte) {
 	h.chainKey = addToChainKey(h.chainKey, data)
 }
 
-func (device *Device) Precompute(peer *Peer) {
-	h := &peer.handshake
-	h.precomputedStaticStatic = device.privateKey.sharedSecret(h.remoteStatic)
-}
-
 func (device *Device) CreateMessageInitial(peer *Peer) (*MessageInital, error) {
 	handshake := &peer.handshake
 	handshake.mutex.Lock()
@@ -116,16 +120,17 @@ func (device *Device) CreateMessageInitial(peer *Peer) (*MessageInital, error) {
 
 	msg.Type = MessageInitalType
 	msg.Ephemeral = handshake.localEphemeral.publicKey()
-	msg.Sender, err = device.indices.NewIndex(handshake)
+	handshake.localIndex, err = device.indices.NewIndex(peer)
 
 	if err != nil {
 		return nil, err
 	}
 
+	msg.Sender = handshake.localIndex
 	handshake.addToChainKey(msg.Ephemeral[:])
 	handshake.addToHash(msg.Ephemeral[:])
 
-	// encrypt long-term "identity key"
+	// encrypt identity key
 
 	func() {
 		var key [chacha20poly1305.KeySize]byte
@@ -221,6 +226,7 @@ func (device *Device) ConsumeMessageInitial(msg *MessageInital) *Peer {
 	handshake.chainKey = chainKey
 	handshake.remoteIndex = msg.Sender
 	handshake.remoteEphemeral = msg.Ephemeral
+	handshake.lastTimestamp = timestamp
 	handshake.state = HandshakeInitialConsumed
 	return peer
 }
@@ -237,14 +243,16 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
 	// assign index
 
 	var err error
-	var msg MessageResponse
-	msg.Type = MessageResponseType
-	msg.Sender, err = device.indices.NewIndex(handshake)
-	msg.Reciever = handshake.remoteIndex
+	handshake.localIndex, err = device.indices.NewIndex(peer)
 	if err != nil {
 		return nil, err
 	}
 
+	var msg MessageResponse
+	msg.Type = MessageResponseType
+	msg.Sender = handshake.localIndex
+	msg.Reciever = handshake.remoteIndex
+
 	// create ephemeral key
 
 	handshake.localEphemeral, err = newPrivateKey()
@@ -252,6 +260,7 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
 		return nil, err
 	}
 	msg.Ephemeral = handshake.localEphemeral.publicKey()
+	handshake.addToHash(msg.Ephemeral[:])
 
 	func() {
 		ss := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral)
@@ -269,9 +278,97 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
 
 	func() {
 		aead, _ := chacha20poly1305.New(key[:])
-		aead.Seal(msg.Empty[:0], ZeroNonce[:], EmptyMessage, handshake.hash[:])
+		aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:])
 		handshake.addToHash(msg.Empty[:])
 	}()
 
+	handshake.state = HandshakeResponseCreated
 	return &msg, nil
 }
+
+func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
+	if msg.Type != MessageResponseType {
+		panic(errors.New("bug: invalid message type"))
+	}
+
+	// lookup handshake by reciever
+
+	peer := device.indices.LookupHandshake(msg.Reciever)
+	if peer == nil {
+		return nil
+	}
+	handshake := &peer.handshake
+	handshake.mutex.Lock()
+	defer handshake.mutex.Unlock()
+	if handshake.state != HandshakeInitialCreated {
+		return nil
+	}
+
+	// finish 3-way DH
+
+	hash := addToHash(handshake.hash, msg.Ephemeral[:])
+	chainKey := handshake.chainKey
+
+	func() {
+		ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral)
+		chainKey = addToChainKey(chainKey, ss[:])
+		ss = device.privateKey.sharedSecret(msg.Ephemeral)
+		chainKey = addToChainKey(chainKey, ss[:])
+	}()
+
+	// add preshared key (psk)
+
+	var tau [blake2s.Size]byte
+	var key [chacha20poly1305.KeySize]byte
+	chainKey, tau, key = KDF3(chainKey[:], handshake.presharedKey[:])
+	hash = addToHash(hash, tau[:])
+
+	// authenticate
+
+	aead, _ := chacha20poly1305.New(key[:])
+	_, err := aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:])
+	if err != nil {
+		return nil
+	}
+	hash = addToHash(hash, msg.Empty[:])
+
+	// update handshake state
+
+	handshake.hash = hash
+	handshake.chainKey = chainKey
+	handshake.remoteIndex = msg.Sender
+	handshake.state = HandshakeResponseConsumed
+
+	return peer
+}
+
+func (peer *Peer) NewKeyPair() *KeyPair {
+	handshake := &peer.handshake
+	handshake.mutex.Lock()
+	defer handshake.mutex.Unlock()
+
+	// derive keys
+
+	var sendKey [chacha20poly1305.KeySize]byte
+	var recvKey [chacha20poly1305.KeySize]byte
+
+	if handshake.state == HandshakeResponseConsumed {
+		sendKey, recvKey = KDF2(handshake.chainKey[:], nil)
+	} else if handshake.state == HandshakeResponseCreated {
+		recvKey, sendKey = KDF2(handshake.chainKey[:], nil)
+	} else {
+		return nil
+	}
+
+	// create AEAD instances
+
+	var keyPair KeyPair
+	keyPair.send, _ = chacha20poly1305.New(sendKey[:])
+	keyPair.recv, _ = chacha20poly1305.New(recvKey[:])
+	keyPair.sendNonce = 0
+	keyPair.recvNonce = 0
+
+	peer.handshake.state = HandshakeReset
+
+	return &keyPair
+}
diff --git a/src/noise_test.go b/src/noise_test.go
index 8d6a0fa..ddabf8e 100644
--- a/src/noise_test.go
+++ b/src/noise_test.go
@@ -63,7 +63,9 @@ func TestNoiseHandshake(t *testing.T) {
 
 	/* simulate handshake */
 
-	// Initiation message
+	// initiation message
+
+	t.Log("exchange initiation message")
 
 	msg1, err := dev1.CreateMessageInitial(peer2)
 	assertNil(t, err)
@@ -88,6 +90,68 @@ func TestNoiseHandshake(t *testing.T) {
 		peer2.handshake.hash[:],
 	)
 
-	// Response message
+	// response message
 
+	t.Log("exchange response message")
+
+	msg2, err := dev2.CreateMessageResponse(peer1)
+	assertNil(t, err)
+
+	peer = dev1.ConsumeMessageResponse(msg2)
+	if peer == nil {
+		t.Fatal("handshake failed at response message")
+	}
+
+	assertEqual(
+		t,
+		peer1.handshake.chainKey[:],
+		peer2.handshake.chainKey[:],
+	)
+
+	assertEqual(
+		t,
+		peer1.handshake.hash[:],
+		peer2.handshake.hash[:],
+	)
+
+	// key pairs
+
+	t.Log("deriving keys")
+
+	key1 := peer1.NewKeyPair()
+	key2 := peer2.NewKeyPair()
+
+	if key1 == nil {
+		t.Fatal("failed to dervice key-pair for peer 1")
+	}
+
+	if key2 == nil {
+		t.Fatal("failed to dervice key-pair for peer 2")
+	}
+
+	// encrypting / decryption test
+
+	t.Log("test key pairs")
+
+	func() {
+		testMsg := []byte("wireguard test message 1")
+		var err error
+		var out []byte
+		var nonce [12]byte
+		out = key1.send.Seal(out, nonce[:], testMsg, nil)
+		out, err = key2.recv.Open(out[:0], nonce[:], out, nil)
+		assertNil(t, err)
+		assertEqual(t, out, testMsg)
+	}()
+
+	func() {
+		testMsg := []byte("wireguard test message 2")
+		var err error
+		var out []byte
+		var nonce [12]byte
+		out = key2.send.Seal(out, nonce[:], testMsg, nil)
+		out, err = key1.recv.Open(out[:0], nonce[:], out, nil)
+		assertNil(t, err)
+		assertEqual(t, out, testMsg)
+	}()
 }