device: fix WaitPool sync.Cond usage

The sync.Locker used with a sync.Cond must be acquired when changing
the associated condition, otherwise there is a window within
sync.Cond.Wait() where a wake-up may be missed.

Fixes: 4846070 ("device: use a waiting sync.Pool instead of a channel")
Reviewed-by: Brad Fitzpatrick <bradfitz@tailscale.com>
Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
This commit is contained in:
Jordan Whited 2024-06-27 08:43:41 -07:00 committed by jmwample
parent 27e661d68e
commit deedce495a
No known key found for this signature in database
2 changed files with 9 additions and 6 deletions

View file

@ -7,14 +7,13 @@ package device
import ( import (
"sync" "sync"
"sync/atomic"
) )
type WaitPool struct { type WaitPool struct {
pool sync.Pool pool sync.Pool
cond sync.Cond cond sync.Cond
lock sync.Mutex lock sync.Mutex
count atomic.Uint32 count uint32 // Get calls not yet Put back
max uint32 max uint32
} }
@ -27,10 +26,10 @@ func NewWaitPool(max uint32, new func() any) *WaitPool {
func (p *WaitPool) Get() any { func (p *WaitPool) Get() any {
if p.max != 0 { if p.max != 0 {
p.lock.Lock() p.lock.Lock()
for p.count.Load() >= p.max { for p.count >= p.max {
p.cond.Wait() p.cond.Wait()
} }
p.count.Add(1) p.count++
p.lock.Unlock() p.lock.Unlock()
} }
return p.pool.Get() return p.pool.Get()
@ -41,7 +40,9 @@ func (p *WaitPool) Put(x any) {
if p.max == 0 { if p.max == 0 {
return return
} }
p.count.Add(^uint32(0)) p.lock.Lock()
defer p.lock.Unlock()
p.count--
p.cond.Signal() p.cond.Signal()
} }

View file

@ -32,7 +32,9 @@ func TestWaitPool(t *testing.T) {
wg.Add(workers) wg.Add(workers)
var max atomic.Uint32 var max atomic.Uint32
updateMax := func() { updateMax := func() {
count := p.count.Load() p.lock.Lock()
count := p.count
p.lock.Unlock()
if count > p.max { if count > p.max {
t.Errorf("count (%d) > max (%d)", count, p.max) t.Errorf("count (%d) > max (%d)", count, p.max)
} }