From 7e10ebe1010898c48e3f1cfc12ad42d0bb5c0fa1 Mon Sep 17 00:00:00 2001
From: "Jason A. Donenfeld" <Jason@zx2c4.com>
Date: Mon, 14 May 2018 00:28:30 +0200
Subject: [PATCH] Introduce rwcancel

---
 Makefile                  |   2 +-
 misc.go                   |   2 +-
 rwcancel/rwcancel_unix.go | 132 ++++++++++++++++++++++++++++++++++++++
 tun_linux.go              |  79 ++++-------------------
 uapi_linux.go             |  30 +++++++--
 5 files changed, 170 insertions(+), 75 deletions(-)
 create mode 100644 rwcancel/rwcancel_unix.go

diff --git a/Makefile b/Makefile
index 77eaac9..1513ef5 100644
--- a/Makefile
+++ b/Makefile
@@ -1,6 +1,6 @@
 all: wireguard-go
 
-wireguard-go: $(wildcard *.go)
+wireguard-go: $(wildcard *.go) $(wildcard */*.go)
 	go build -o $@
 
 clean:
diff --git a/misc.go b/misc.go
index f94a617..85a2e80 100644
--- a/misc.go
+++ b/misc.go
@@ -47,7 +47,7 @@ func toInt32(n uint32) int32 {
 	return int32(-(n & mask) + (n & ^mask))
 }
 
-func min(a uint, b uint) uint {
+func min(a, b uint) uint {
 	if a > b {
 		return b
 	}
diff --git a/rwcancel/rwcancel_unix.go b/rwcancel/rwcancel_unix.go
new file mode 100644
index 0000000..cd3661f
--- /dev/null
+++ b/rwcancel/rwcancel_unix.go
@@ -0,0 +1,132 @@
+/* SPDX-License-Identifier: GPL-2.0
+ *
+ * Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
+ */
+
+package rwcancel
+
+import (
+	"errors"
+	"golang.org/x/sys/unix"
+	"os"
+	"runtime"
+	"syscall"
+)
+
+type RWCancel struct {
+	fd            int
+	closingReader *os.File
+	closingWriter *os.File
+}
+
+type fdSet struct {
+	fdset unix.FdSet
+}
+
+func (fdset *fdSet) set(i int) {
+	bits := 32 << (^uint(0) >> 63)
+	fdset.fdset.Bits[i/bits] |= 1 << uint(i%bits)
+}
+
+func (fdset *fdSet) check(i int) bool {
+	bits := 32 << (^uint(0) >> 63)
+	return (fdset.fdset.Bits[i/bits] & (1 << uint(i%bits))) != 0
+}
+
+func max(a, b int) int {
+	if a > b {
+		return a
+	}
+	return b
+}
+
+func NewRWCancel(fd int) (*RWCancel, error) {
+	err := unix.SetNonblock(fd, true)
+	if err != nil {
+		return nil, err
+	}
+	rwcancel := RWCancel{fd: fd}
+
+	rwcancel.closingReader, rwcancel.closingWriter, err = os.Pipe()
+	if err != nil {
+		return nil, err
+	}
+
+	runtime.SetFinalizer(&rwcancel, func(rw *RWCancel) {
+		rw.Cancel()
+	})
+
+	return &rwcancel, nil
+}
+
+/* https://golang.org/src/crypto/rand/eagain.go */
+func ErrorIsEAGAIN(err error) bool {
+	if pe, ok := err.(*os.PathError); ok {
+		if errno, ok := pe.Err.(syscall.Errno); ok && errno == syscall.EAGAIN {
+			return true
+		}
+	}
+	if errno, ok := err.(syscall.Errno); ok && errno == syscall.EAGAIN {
+		return true
+	}
+	return false
+}
+
+func (rw *RWCancel) ReadyRead() bool {
+	closeFd := int(rw.closingReader.Fd())
+	fdset := fdSet{}
+	fdset.set(rw.fd)
+	fdset.set(closeFd)
+	_, err := unix.Select(max(rw.fd, closeFd)+1, &fdset.fdset, nil, nil, nil)
+	if err != nil {
+		return false
+	}
+	if fdset.check(closeFd) {
+		return false
+	}
+	return fdset.check(rw.fd)
+}
+
+func (rw *RWCancel) ReadyWrite() bool {
+	closeFd := int(rw.closingReader.Fd())
+	fdset := fdSet{}
+	fdset.set(rw.fd)
+	fdset.set(closeFd)
+	_, err := unix.Select(max(rw.fd, closeFd)+1, nil, &fdset.fdset, nil, nil)
+	if err != nil {
+		return false
+	}
+	if fdset.check(closeFd) {
+		return false
+	}
+	return fdset.check(rw.fd)
+}
+
+func (rw *RWCancel) Read(p []byte) (n int, err error) {
+	for {
+		n, err := unix.Read(rw.fd, p)
+		if err == nil || !ErrorIsEAGAIN(err) {
+			return n, err
+		}
+		if !rw.ReadyRead() {
+			return 0, errors.New("fd closed")
+		}
+	}
+}
+
+func (rw *RWCancel) Write(p []byte) (n int, err error) {
+	for {
+		n, err := unix.Write(rw.fd, p)
+		if err == nil || !ErrorIsEAGAIN(err) {
+			return n, err
+		}
+		if !rw.ReadyWrite() {
+			return 0, errors.New("fd closed")
+		}
+	}
+}
+
+func (rw *RWCancel) Cancel() (err error) {
+	_, err = rw.closingWriter.Write([]byte{0})
+	return
+}
diff --git a/tun_linux.go b/tun_linux.go
index 9f60d2b..3510f94 100644
--- a/tun_linux.go
+++ b/tun_linux.go
@@ -11,6 +11,7 @@ package main
  */
 
 import (
+	"./rwcancel"
 	"bytes"
 	"encoding/binary"
 	"errors"
@@ -20,7 +21,6 @@ import (
 	"net"
 	"os"
 	"strconv"
-	"syscall"
 	"time"
 	"unsafe"
 )
@@ -31,14 +31,13 @@ const (
 )
 
 type NativeTun struct {
-	fd            *os.File
-	index         int32         // if index
-	name          string        // name of interface
-	errors        chan error    // async error handling
-	events        chan TUNEvent // device related events
-	nopi          bool          // the device was pased IFF_NO_PI
-	closingReader *os.File
-	closingWriter *os.File
+	fd       *os.File
+	index    int32         // if index
+	name     string        // name of interface
+	errors   chan error    // async error handling
+	events   chan TUNEvent // device related events
+	nopi     bool          // the device was pased IFF_NO_PI
+	rwcancel *rwcancel.RWCancel
 }
 
 func (tun *NativeTun) File() *os.File {
@@ -305,43 +304,6 @@ func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
 	return tun.fd.Write(buff)
 }
 
-type FdSet struct {
-	fdset unix.FdSet
-}
-
-func (fdset *FdSet) set(i int) {
-	bits := 32 << (^uint(0) >> 63)
-	fdset.fdset.Bits[i/bits] |= 1 << uint(i%bits)
-}
-
-func (fdset *FdSet) check(i int) bool {
-	bits := 32 << (^uint(0) >> 63)
-	return (fdset.fdset.Bits[i/bits] & (1 << uint(i%bits))) != 0
-}
-
-func max(a, b int) int {
-	if a > b {
-		return a
-	}
-	return b
-}
-
-func (tun *NativeTun) readyRead() bool {
-	readFd := int(tun.fd.Fd())
-	closeFd := int(tun.closingReader.Fd())
-	fdset := FdSet{}
-	fdset.set(readFd)
-	fdset.set(closeFd)
-	_, err := unix.Select(max(readFd, closeFd)+1, &fdset.fdset, nil, nil, nil)
-	if err != nil {
-		return false
-	}
-	if fdset.check(closeFd) {
-		return false
-	}
-	return fdset.check(readFd)
-}
-
 func (tun *NativeTun) doRead(buff []byte, offset int) (int, error) {
 	select {
 	case err := <-tun.errors:
@@ -360,24 +322,14 @@ func (tun *NativeTun) doRead(buff []byte, offset int) (int, error) {
 	}
 }
 
-/* https://golang.org/src/crypto/rand/eagain.go */
-func unixIsEAGAIN(err error) bool {
-	if pe, ok := err.(*os.PathError); ok {
-		if errno, ok := pe.Err.(syscall.Errno); ok && errno == syscall.EAGAIN {
-			return true
-		}
-	}
-	return false
-}
-
 func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
 	for {
 		n, err := tun.doRead(buff, offset)
-		if err == nil || !unixIsEAGAIN(err) {
+		if err == nil || !rwcancel.ErrorIsEAGAIN(err) {
 			return n, err
 		}
-		if !tun.readyRead() {
-			return 0, errors.New("Tun device closed")
+		if !tun.rwcancel.ReadyRead() {
+			return 0, errors.New("tun device closed")
 		}
 	}
 }
@@ -391,7 +343,7 @@ func (tun *NativeTun) Close() error {
 	if err != nil {
 		return err
 	}
-	tun.closingWriter.Write([]byte{0})
+	tun.rwcancel.Cancel()
 	close(tun.events)
 	return nil
 }
@@ -450,7 +402,7 @@ func CreateTUNFromFile(fd *os.File) (TUNDevice, error) {
 	}
 	var err error
 
-	err = unix.SetNonblock(int(fd.Fd()), true)
+	device.rwcancel, err = rwcancel.NewRWCancel(int(fd.Fd()))
 	if err != nil {
 		return nil, err
 	}
@@ -460,11 +412,6 @@ func CreateTUNFromFile(fd *os.File) (TUNDevice, error) {
 		return nil, err
 	}
 
-	device.closingReader, device.closingWriter, err = os.Pipe()
-	if err != nil {
-		return nil, err
-	}
-
 	// start event listener
 
 	device.index, err = getIFIndex(device.name)
diff --git a/uapi_linux.go b/uapi_linux.go
index c40472e..67024e9 100644
--- a/uapi_linux.go
+++ b/uapi_linux.go
@@ -6,6 +6,7 @@
 package main
 
 import (
+	"./rwcancel"
 	"errors"
 	"fmt"
 	"golang.org/x/sys/unix"
@@ -24,10 +25,11 @@ const (
 )
 
 type UAPIListener struct {
-	listener  net.Listener // unix socket listener
-	connNew   chan net.Conn
-	connErr   chan error
-	inotifyFd int
+	listener        net.Listener // unix socket listener
+	connNew         chan net.Conn
+	connErr         chan error
+	inotifyFd       int
+	inotifyRWCancel *rwcancel.RWCancel
 }
 
 func (l *UAPIListener) Accept() (net.Conn, error) {
@@ -45,10 +47,14 @@ func (l *UAPIListener) Accept() (net.Conn, error) {
 func (l *UAPIListener) Close() error {
 	err1 := unix.Close(l.inotifyFd)
 	err2 := l.listener.Close()
+	err3 := l.inotifyRWCancel.Cancel()
 	if err1 != nil {
 		return err1
 	}
-	return err2
+	if err2 != nil {
+		return err2
+	}
+	return err3
 }
 
 func (l *UAPIListener) Addr() net.Addr {
@@ -94,15 +100,25 @@ func UAPIListen(name string, file *os.File) (net.Listener, error) {
 		return nil, err
 	}
 
+	uapi.inotifyRWCancel, err = rwcancel.NewRWCancel(uapi.inotifyFd)
+	if err != nil {
+		unix.Close(uapi.inotifyFd)
+		return nil, err
+	}
+
 	go func(l *UAPIListener) {
-		var buff [4096]byte
+		var buff [0]byte
 		for {
 			// start with lstat to avoid race condition
 			if _, err := os.Lstat(socketPath); os.IsNotExist(err) {
 				l.connErr <- err
 				return
 			}
-			unix.Read(uapi.inotifyFd, buff[:])
+			_, err := uapi.inotifyRWCancel.Read(buff[:])
+			if err != nil {
+				l.connErr <- err
+				return
+			}
 		}
 	}(uapi)