mirror of
https://github.com/amnezia-vpn/amneziawg-go.git
synced 2025-07-28 07:52:50 +02:00
device: add support for removing allowedips individually
This pairs with the recent change in wireguard-tools. Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
This commit is contained in:
parent
6768090667
commit
d5359f52f0
3 changed files with 125 additions and 34 deletions
|
@ -223,19 +223,11 @@ func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(prefix netip.Prefix)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
|
func (node *trieEntry) remove() {
|
||||||
table.mutex.Lock()
|
|
||||||
defer table.mutex.Unlock()
|
|
||||||
|
|
||||||
var next *list.Element
|
|
||||||
for elem := peer.trieEntries.Front(); elem != nil; elem = next {
|
|
||||||
next = elem.Next()
|
|
||||||
node := elem.Value.(*trieEntry)
|
|
||||||
|
|
||||||
node.removeFromPeerEntries()
|
node.removeFromPeerEntries()
|
||||||
node.peer = nil
|
node.peer = nil
|
||||||
if node.child[0] != nil && node.child[1] != nil {
|
if node.child[0] != nil && node.child[1] != nil {
|
||||||
continue
|
return
|
||||||
}
|
}
|
||||||
bit := 0
|
bit := 0
|
||||||
if node.child[0] == nil {
|
if node.child[0] == nil {
|
||||||
|
@ -248,12 +240,12 @@ func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
|
||||||
*node.parent.parentBit = child
|
*node.parent.parentBit = child
|
||||||
if node.child[0] != nil || node.child[1] != nil || node.parent.parentBitType > 1 {
|
if node.child[0] != nil || node.child[1] != nil || node.parent.parentBitType > 1 {
|
||||||
node.zeroizePointers()
|
node.zeroizePointers()
|
||||||
continue
|
return
|
||||||
}
|
}
|
||||||
parent := (*trieEntry)(unsafe.Pointer(uintptr(unsafe.Pointer(node.parent.parentBit)) - unsafe.Offsetof(node.child) - unsafe.Sizeof(node.child[0])*uintptr(node.parent.parentBitType)))
|
parent := (*trieEntry)(unsafe.Pointer(uintptr(unsafe.Pointer(node.parent.parentBit)) - unsafe.Offsetof(node.child) - unsafe.Sizeof(node.child[0])*uintptr(node.parent.parentBitType)))
|
||||||
if parent.peer != nil {
|
if parent.peer != nil {
|
||||||
node.zeroizePointers()
|
node.zeroizePointers()
|
||||||
continue
|
return
|
||||||
}
|
}
|
||||||
child = parent.child[node.parent.parentBitType^1]
|
child = parent.child[node.parent.parentBitType^1]
|
||||||
if child != nil {
|
if child != nil {
|
||||||
|
@ -263,6 +255,37 @@ func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
|
||||||
node.zeroizePointers()
|
node.zeroizePointers()
|
||||||
parent.zeroizePointers()
|
parent.zeroizePointers()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (table *AllowedIPs) Remove(prefix netip.Prefix, peer *Peer) {
|
||||||
|
table.mutex.Lock()
|
||||||
|
defer table.mutex.Unlock()
|
||||||
|
var node *trieEntry
|
||||||
|
var exact bool
|
||||||
|
|
||||||
|
if prefix.Addr().Is6() {
|
||||||
|
ip := prefix.Addr().As16()
|
||||||
|
node, exact = table.IPv6.nodePlacement(ip[:], uint8(prefix.Bits()))
|
||||||
|
} else if prefix.Addr().Is4() {
|
||||||
|
ip := prefix.Addr().As4()
|
||||||
|
node, exact = table.IPv4.nodePlacement(ip[:], uint8(prefix.Bits()))
|
||||||
|
} else {
|
||||||
|
panic(errors.New("removing unknown address type"))
|
||||||
|
}
|
||||||
|
if !exact || node == nil || peer != node.peer {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
node.remove()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
|
||||||
|
table.mutex.Lock()
|
||||||
|
defer table.mutex.Unlock()
|
||||||
|
|
||||||
|
var next *list.Element
|
||||||
|
for elem := peer.trieEntries.Front(); elem != nil; elem = next {
|
||||||
|
next = elem.Next()
|
||||||
|
elem.Value.(*trieEntry).remove()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (table *AllowedIPs) Insert(prefix netip.Prefix, peer *Peer) {
|
func (table *AllowedIPs) Insert(prefix netip.Prefix, peer *Peer) {
|
||||||
|
|
|
@ -101,6 +101,10 @@ func TestTrieIPv4(t *testing.T) {
|
||||||
allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4([4]byte{a, b, c, d}), int(cidr)), peer)
|
allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4([4]byte{a, b, c, d}), int(cidr)), peer)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
remove := func(peer *Peer, a, b, c, d byte, cidr uint8) {
|
||||||
|
allowedIPs.Remove(netip.PrefixFrom(netip.AddrFrom4([4]byte{a, b, c, d}), int(cidr)), peer)
|
||||||
|
}
|
||||||
|
|
||||||
assertEQ := func(peer *Peer, a, b, c, d byte) {
|
assertEQ := func(peer *Peer, a, b, c, d byte) {
|
||||||
p := allowedIPs.Lookup([]byte{a, b, c, d})
|
p := allowedIPs.Lookup([]byte{a, b, c, d})
|
||||||
if p != peer {
|
if p != peer {
|
||||||
|
@ -176,6 +180,21 @@ func TestTrieIPv4(t *testing.T) {
|
||||||
allowedIPs.RemoveByPeer(a)
|
allowedIPs.RemoveByPeer(a)
|
||||||
|
|
||||||
assertNEQ(a, 192, 168, 0, 1)
|
assertNEQ(a, 192, 168, 0, 1)
|
||||||
|
|
||||||
|
insert(a, 1, 0, 0, 0, 32)
|
||||||
|
insert(a, 192, 0, 0, 0, 24)
|
||||||
|
assertEQ(a, 1, 0, 0, 0)
|
||||||
|
assertEQ(a, 192, 0, 0, 1)
|
||||||
|
remove(a, 192, 0, 0, 0, 32)
|
||||||
|
assertEQ(a, 192, 0, 0, 1)
|
||||||
|
remove(nil, 192, 0, 0, 0, 24)
|
||||||
|
assertEQ(a, 192, 0, 0, 1)
|
||||||
|
remove(b, 192, 0, 0, 0, 24)
|
||||||
|
assertEQ(a, 192, 0, 0, 1)
|
||||||
|
remove(a, 192, 0, 0, 0, 24)
|
||||||
|
assertNEQ(a, 192, 0, 0, 1)
|
||||||
|
remove(a, 1, 0, 0, 0, 32)
|
||||||
|
assertNEQ(a, 1, 0, 0, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
/* Test ported from kernel implementation:
|
/* Test ported from kernel implementation:
|
||||||
|
@ -211,6 +230,15 @@ func TestTrieIPv6(t *testing.T) {
|
||||||
allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(*(*[16]byte)(addr)), int(cidr)), peer)
|
allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(*(*[16]byte)(addr)), int(cidr)), peer)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
remove := func(peer *Peer, a, b, c, d uint32, cidr uint8) {
|
||||||
|
var addr []byte
|
||||||
|
addr = append(addr, expand(a)...)
|
||||||
|
addr = append(addr, expand(b)...)
|
||||||
|
addr = append(addr, expand(c)...)
|
||||||
|
addr = append(addr, expand(d)...)
|
||||||
|
allowedIPs.Remove(netip.PrefixFrom(netip.AddrFrom16(*(*[16]byte)(addr)), int(cidr)), peer)
|
||||||
|
}
|
||||||
|
|
||||||
assertEQ := func(peer *Peer, a, b, c, d uint32) {
|
assertEQ := func(peer *Peer, a, b, c, d uint32) {
|
||||||
var addr []byte
|
var addr []byte
|
||||||
addr = append(addr, expand(a)...)
|
addr = append(addr, expand(a)...)
|
||||||
|
@ -223,6 +251,18 @@ func TestTrieIPv6(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
assertNEQ := func(peer *Peer, a, b, c, d uint32) {
|
||||||
|
var addr []byte
|
||||||
|
addr = append(addr, expand(a)...)
|
||||||
|
addr = append(addr, expand(b)...)
|
||||||
|
addr = append(addr, expand(c)...)
|
||||||
|
addr = append(addr, expand(d)...)
|
||||||
|
p := allowedIPs.Lookup(addr)
|
||||||
|
if p == peer {
|
||||||
|
t.Error("Assert NEQ failed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
insert(d, 0x26075300, 0x60006b00, 0, 0xc05f0543, 128)
|
insert(d, 0x26075300, 0x60006b00, 0, 0xc05f0543, 128)
|
||||||
insert(c, 0x26075300, 0x60006b00, 0, 0, 64)
|
insert(c, 0x26075300, 0x60006b00, 0, 0, 64)
|
||||||
insert(e, 0, 0, 0, 0, 0)
|
insert(e, 0, 0, 0, 0, 0)
|
||||||
|
@ -244,4 +284,21 @@ func TestTrieIPv6(t *testing.T) {
|
||||||
assertEQ(h, 0x24046800, 0x40040800, 0, 0)
|
assertEQ(h, 0x24046800, 0x40040800, 0, 0)
|
||||||
assertEQ(h, 0x24046800, 0x40040800, 0x10101010, 0x10101010)
|
assertEQ(h, 0x24046800, 0x40040800, 0x10101010, 0x10101010)
|
||||||
assertEQ(a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef)
|
assertEQ(a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef)
|
||||||
|
|
||||||
|
insert(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128)
|
||||||
|
insert(a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98)
|
||||||
|
assertEQ(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef)
|
||||||
|
assertEQ(a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0x10101010)
|
||||||
|
remove(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 96)
|
||||||
|
assertEQ(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef)
|
||||||
|
remove(nil, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128)
|
||||||
|
assertEQ(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef)
|
||||||
|
remove(b, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128)
|
||||||
|
assertEQ(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef)
|
||||||
|
remove(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128)
|
||||||
|
assertNEQ(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef)
|
||||||
|
remove(b, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98)
|
||||||
|
assertEQ(a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0x10101010)
|
||||||
|
remove(a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98)
|
||||||
|
assertNEQ(a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0x10101010)
|
||||||
}
|
}
|
||||||
|
|
|
@ -497,7 +497,14 @@ func (device *Device) handlePeerLine(
|
||||||
device.allowedips.RemoveByPeer(peer.Peer)
|
device.allowedips.RemoveByPeer(peer.Peer)
|
||||||
|
|
||||||
case "allowed_ip":
|
case "allowed_ip":
|
||||||
device.log.Verbosef("%v - UAPI: Adding allowedip", peer.Peer)
|
add := true
|
||||||
|
verb := "Adding"
|
||||||
|
if len(value) > 0 && value[0] == '-' {
|
||||||
|
add = false
|
||||||
|
verb = "Removing"
|
||||||
|
value = value[1:]
|
||||||
|
}
|
||||||
|
device.log.Verbosef("%v - UAPI: %s allowedip", peer.Peer, verb)
|
||||||
prefix, err := netip.ParsePrefix(value)
|
prefix, err := netip.ParsePrefix(value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set allowed ip: %w", err)
|
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set allowed ip: %w", err)
|
||||||
|
@ -505,7 +512,11 @@ func (device *Device) handlePeerLine(
|
||||||
if peer.dummy {
|
if peer.dummy {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
if add {
|
||||||
device.allowedips.Insert(prefix, peer.Peer)
|
device.allowedips.Insert(prefix, peer.Peer)
|
||||||
|
} else {
|
||||||
|
device.allowedips.Remove(prefix, peer.Peer)
|
||||||
|
}
|
||||||
|
|
||||||
case "protocol_version":
|
case "protocol_version":
|
||||||
if value != "1" {
|
if value != "1" {
|
||||||
|
|
Loading…
Add table
Reference in a new issue