mirror of
https://github.com/amnezia-vpn/amneziawg-go.git
synced 2025-07-28 16:02:50 +02:00
device: fix WaitPool sync.Cond usage
The sync.Locker used with a sync.Cond must be acquired when changing
the associated condition, otherwise there is a window within
sync.Cond.Wait() where a wake-up may be missed.
Fixes: 4846070
("device: use a waiting sync.Pool instead of a channel")
Reviewed-by: Brad Fitzpatrick <bradfitz@tailscale.com>
Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
This commit is contained in:
parent
27e661d68e
commit
deedce495a
2 changed files with 9 additions and 6 deletions
|
@ -7,14 +7,13 @@ package device
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type WaitPool struct {
|
type WaitPool struct {
|
||||||
pool sync.Pool
|
pool sync.Pool
|
||||||
cond sync.Cond
|
cond sync.Cond
|
||||||
lock sync.Mutex
|
lock sync.Mutex
|
||||||
count atomic.Uint32
|
count uint32 // Get calls not yet Put back
|
||||||
max uint32
|
max uint32
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -27,10 +26,10 @@ func NewWaitPool(max uint32, new func() any) *WaitPool {
|
||||||
func (p *WaitPool) Get() any {
|
func (p *WaitPool) Get() any {
|
||||||
if p.max != 0 {
|
if p.max != 0 {
|
||||||
p.lock.Lock()
|
p.lock.Lock()
|
||||||
for p.count.Load() >= p.max {
|
for p.count >= p.max {
|
||||||
p.cond.Wait()
|
p.cond.Wait()
|
||||||
}
|
}
|
||||||
p.count.Add(1)
|
p.count++
|
||||||
p.lock.Unlock()
|
p.lock.Unlock()
|
||||||
}
|
}
|
||||||
return p.pool.Get()
|
return p.pool.Get()
|
||||||
|
@ -41,7 +40,9 @@ func (p *WaitPool) Put(x any) {
|
||||||
if p.max == 0 {
|
if p.max == 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
p.count.Add(^uint32(0))
|
p.lock.Lock()
|
||||||
|
defer p.lock.Unlock()
|
||||||
|
p.count--
|
||||||
p.cond.Signal()
|
p.cond.Signal()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -32,7 +32,9 @@ func TestWaitPool(t *testing.T) {
|
||||||
wg.Add(workers)
|
wg.Add(workers)
|
||||||
var max atomic.Uint32
|
var max atomic.Uint32
|
||||||
updateMax := func() {
|
updateMax := func() {
|
||||||
count := p.count.Load()
|
p.lock.Lock()
|
||||||
|
count := p.count
|
||||||
|
p.lock.Unlock()
|
||||||
if count > p.max {
|
if count > p.max {
|
||||||
t.Errorf("count (%d) > max (%d)", count, p.max)
|
t.Errorf("count (%d) > max (%d)", count, p.max)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Reference in a new issue