Skip to content

Commit c03b3d5

Browse files
committed
fix potential race in listeners and add tests for that case
1 parent f9fcb1c commit c03b3d5

File tree

3 files changed

+104
-14
lines changed

3 files changed

+104
-14
lines changed

dequeue.go

Lines changed: 49 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,20 +14,27 @@ type BlockingDequeue[T any] struct {
1414

1515
writeCond *sync.Cond // condition used to lock and notify about writing to the encapsulated list
1616
capacityLock *sync.RWMutex // lock used to protect the capacity
17+
onFullLock *sync.RWMutex // lock used to protect the onFull callback
18+
onEmptyLock *sync.RWMutex // lock used to protect the onEmpty callback
1719

1820
capacity int
1921

20-
OnFull func() // Optional callback function invoked when the dequeue is full
21-
OnEmpty func() // Optional callback function invoked when the dequeue is empty
22+
onFull func() // Optional callback function invoked when the dequeue is full
23+
onEmpty func() // Optional callback function invoked when the dequeue is empty
2224
}
2325

2426
// Creates a new blocking dequeue with infinite capacity.
2527
// The dequeue MUST only be created using this method.
2628
func NewBlockingDequeue[T any]() *BlockingDequeue[T] {
2729
d := new(BlockingDequeue[T])
2830
d.list = list.New()
31+
2932
d.writeCond = sync.NewCond(&sync.Mutex{})
33+
3034
d.capacityLock = &sync.RWMutex{}
35+
d.onFullLock = &sync.RWMutex{}
36+
d.onEmptyLock = &sync.RWMutex{}
37+
3138
return d
3239
}
3340

@@ -47,9 +54,12 @@ func (d *BlockingDequeue[T]) PushFront(item T) {
4754
// Notify the consumer that an item has been added
4855
defer d.writeCond.Broadcast()
4956

57+
d.onFullLock.RLock()
58+
defer d.onFullLock.RUnlock()
59+
5060
// Call the OnFull callback if the dequeue is full
51-
if d.isFull_unsafe() && d.OnFull != nil {
52-
d.OnFull()
61+
if d.isFull_unsafe() && d.onFull != nil {
62+
d.onFull()
5363
}
5464
}
5565

@@ -68,8 +78,11 @@ func (d *BlockingDequeue[T]) PushBack(item T) {
6878
defer d.writeCond.Broadcast()
6979

7080
// Call the OnFull callback if the dequeue is full
71-
if d.isFull_unsafe() && d.OnFull != nil {
72-
d.OnFull()
81+
d.onFullLock.RLock()
82+
defer d.onFullLock.RUnlock()
83+
84+
if d.isFull_unsafe() && d.onFull != nil {
85+
d.onFull()
7386
}
7487
}
7588

@@ -88,8 +101,11 @@ func (d *BlockingDequeue[T]) PopFront() T {
88101
defer d.writeCond.Broadcast()
89102

90103
// Call the OnEmpty callback if the dequeue is empty
91-
if d.isEmpty_unsafe() && d.OnEmpty != nil {
92-
d.OnEmpty()
104+
d.onEmptyLock.RLock()
105+
defer d.onEmptyLock.RUnlock()
106+
107+
if d.isEmpty_unsafe() && d.onEmpty != nil {
108+
d.onEmpty()
93109
}
94110

95111
return item
@@ -110,8 +126,11 @@ func (d *BlockingDequeue[T]) PopBack() T {
110126
defer d.writeCond.Broadcast()
111127

112128
// Call the OnEmpty callback if the dequeue is empty
113-
if d.isEmpty_unsafe() && d.OnEmpty != nil {
114-
d.OnEmpty()
129+
d.onEmptyLock.RLock()
130+
defer d.onEmptyLock.RUnlock()
131+
132+
if d.isEmpty_unsafe() && d.onEmpty != nil {
133+
d.onEmpty()
115134
}
116135

117136
return item
@@ -143,6 +162,26 @@ func (d *BlockingDequeue[T]) PeekBack() T {
143162
return element.Value.(T)
144163
}
145164

165+
// ================================[Listeners related]================================
166+
167+
// Set the callback function invoked when the dequeue is full.
168+
// Attempting to update the dequeue in the callback function will cause a deadlock.
169+
func (d *BlockingDequeue[T]) SetOnFull(onFull func()) {
170+
d.onFullLock.Lock()
171+
defer d.onFullLock.Unlock()
172+
173+
d.onFull = onFull
174+
}
175+
176+
// Set the callback function invoked when the dequeue is empty.
177+
// Attempting to update the dequeue in the callback function will cause a deadlock.
178+
func (d *BlockingDequeue[T]) SetOnEmpty(onEmpty func()) {
179+
d.onEmptyLock.Lock()
180+
defer d.onEmptyLock.Unlock()
181+
182+
d.onEmpty = onEmpty
183+
}
184+
146185
// ================================[Size/Capacity related]================================
147186

148187
// Set dequeue capacity, if capacity is 0, dequeue is infinite.

dequeue_integration_test.go

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,3 +235,54 @@ func TestConcurrentCapacityChange(t *testing.T) {
235235

236236
// No checks need to be done, if the test is here without race conditions, it passed
237237
}
238+
239+
func TestConcurrentUpdatesToOnEmptyListener(t *testing.T) {
240+
dequeue := NewBlockingDequeue[int]()
241+
242+
getDummyCallback := func() func() { return func() {} }
243+
244+
// Update onEmpty callback while it's being called
245+
go func() {
246+
time.Sleep(10 * time.Millisecond)
247+
for i := 0; i < 100; i++ {
248+
dequeue.SetOnEmpty(getDummyCallback())
249+
}
250+
}()
251+
252+
// Update the dequeue to call the onEmpty callback
253+
for i := 0; i < 100; i++ {
254+
go func(val int) {
255+
dequeue.PushBack(val)
256+
dequeue.PopBack()
257+
258+
dequeue.PushFront(val)
259+
dequeue.PopFront()
260+
}(i)
261+
}
262+
}
263+
264+
func TestConcurrentUpdatesToOnFullListener(t *testing.T) {
265+
dequeue := NewBlockingDequeue[int]()
266+
dequeue.SetCapacity(1)
267+
268+
getDummyCallback := func() func() { return func() {} }
269+
270+
// Update onEmpty callback while it's being called
271+
go func() {
272+
time.Sleep(10 * time.Millisecond)
273+
for i := 0; i < 100; i++ {
274+
dequeue.SetOnFull(getDummyCallback())
275+
}
276+
}()
277+
278+
// Update the dequeue to call the onEmpty callback
279+
for i := 0; i < 100; i++ {
280+
go func(val int) {
281+
dequeue.PushBack(val)
282+
dequeue.PopBack()
283+
284+
dequeue.PushFront(val)
285+
dequeue.PopFront()
286+
}(i)
287+
}
288+
}

dequeue_unit_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -386,9 +386,9 @@ func TestOnEmpty(t *testing.T) {
386386
dequeue := NewBlockingDequeue[int]()
387387

388388
called := 0
389-
dequeue.OnEmpty = func() {
389+
dequeue.SetOnEmpty(func() {
390390
called++
391-
}
391+
})
392392

393393
dequeue.PushBack(1)
394394
dequeue.PopFront()
@@ -410,9 +410,9 @@ func TestOnFull(t *testing.T) {
410410
dequeue.SetCapacity(3)
411411

412412
called := 0
413-
dequeue.OnFull = func() {
413+
dequeue.SetOnFull(func() {
414414
called++
415-
}
415+
})
416416

417417
dequeue.PushBack(1)
418418
dequeue.PushBack(2)

0 commit comments

Comments
 (0)