diff --git a/device/pools.go b/device/pools.go index 94f3dc7..55d2be7 100644 --- a/device/pools.go +++ b/device/pools.go @@ -7,14 +7,13 @@ package device import ( "sync" - "sync/atomic" ) type WaitPool struct { pool sync.Pool cond sync.Cond lock sync.Mutex - count atomic.Uint32 + count uint32 // Get calls not yet Put back max uint32 } @@ -27,10 +26,10 @@ func NewWaitPool(max uint32, new func() any) *WaitPool { func (p *WaitPool) Get() any { if p.max != 0 { p.lock.Lock() - for p.count.Load() >= p.max { + for p.count >= p.max { p.cond.Wait() } - p.count.Add(1) + p.count++ p.lock.Unlock() } return p.pool.Get() @@ -41,7 +40,9 @@ func (p *WaitPool) Put(x any) { if p.max == 0 { return } - p.count.Add(^uint32(0)) + p.lock.Lock() + defer p.lock.Unlock() + p.count-- p.cond.Signal() } diff --git a/device/pools_test.go b/device/pools_test.go index 82d7493..538230b 100644 --- a/device/pools_test.go +++ b/device/pools_test.go @@ -32,7 +32,9 @@ func TestWaitPool(t *testing.T) { wg.Add(workers) var max atomic.Uint32 updateMax := func() { - count := p.count.Load() + p.lock.Lock() + count := p.count + p.lock.Unlock() if count > p.max { t.Errorf("count (%d) > max (%d)", count, p.max) }