@@ -18,13 +18,15 @@ import (
18
18
"bytes"
19
19
"errors"
20
20
"fmt"
21
+ "sync"
21
22
)
22
23
23
24
// callSet represents a set of expected calls, indexed by receiver and method
24
25
// name.
25
26
type callSet struct {
26
27
// Calls that are still expected.
27
- expected map [callSetKey ][]* Call
28
+ expected map [callSetKey ][]* Call
29
+ expectedMu * sync.Mutex
28
30
// Calls that have been exhausted.
29
31
exhausted map [callSetKey ][]* Call
30
32
// when set to true,
@@ -39,14 +41,16 @@ type callSetKey struct {
39
41
40
42
func newCallSet () * callSet {
41
43
return & callSet {
42
- expected : make (map [callSetKey ][]* Call ),
43
- exhausted : make (map [callSetKey ][]* Call ),
44
+ expected : make (map [callSetKey ][]* Call ),
45
+ expectedMu : & sync.Mutex {},
46
+ exhausted : make (map [callSetKey ][]* Call ),
44
47
}
45
48
}
46
49
47
50
func newOverridableCallSet () * callSet {
48
51
return & callSet {
49
52
expected : make (map [callSetKey ][]* Call ),
53
+ expectedMu : & sync.Mutex {},
50
54
exhausted : make (map [callSetKey ][]* Call ),
51
55
allowOverride : true ,
52
56
}
@@ -55,6 +59,10 @@ func newOverridableCallSet() *callSet {
55
59
// Add adds a new expected call.
56
60
func (cs callSet ) Add (call * Call ) {
57
61
key := callSetKey {call .receiver , call .method }
62
+
63
+ cs .expectedMu .Lock ()
64
+ defer cs .expectedMu .Unlock ()
65
+
58
66
m := cs .expected
59
67
if call .exhausted () {
60
68
m = cs .exhausted
@@ -70,6 +78,10 @@ func (cs callSet) Add(call *Call) {
70
78
// Remove removes an expected call.
71
79
func (cs callSet ) Remove (call * Call ) {
72
80
key := callSetKey {call .receiver , call .method }
81
+
82
+ cs .expectedMu .Lock ()
83
+ defer cs .expectedMu .Unlock ()
84
+
73
85
calls := cs .expected [key ]
74
86
for i , c := range calls {
75
87
if c == call {
@@ -85,6 +97,9 @@ func (cs callSet) Remove(call *Call) {
85
97
func (cs callSet ) FindMatch (receiver interface {}, method string , args []interface {}) (* Call , error ) {
86
98
key := callSetKey {receiver , method }
87
99
100
+ cs .expectedMu .Lock ()
101
+ defer cs .expectedMu .Unlock ()
102
+
88
103
// Search through the expected calls.
89
104
expected := cs .expected [key ]
90
105
var callsErrors bytes.Buffer
@@ -119,6 +134,9 @@ func (cs callSet) FindMatch(receiver interface{}, method string, args []interfac
119
134
120
135
// Failures returns the calls that are not satisfied.
121
136
func (cs callSet ) Failures () []* Call {
137
+ cs .expectedMu .Lock ()
138
+ defer cs .expectedMu .Unlock ()
139
+
122
140
failures := make ([]* Call , 0 , len (cs .expected ))
123
141
for _ , calls := range cs .expected {
124
142
for _ , call := range calls {
@@ -129,3 +147,19 @@ func (cs callSet) Failures() []*Call {
129
147
}
130
148
return failures
131
149
}
150
+
151
+ // Satisfied returns true in case all expected calls in this callSet are satisfied.
152
+ func (cs callSet ) Satisfied () bool {
153
+ cs .expectedMu .Lock ()
154
+ defer cs .expectedMu .Unlock ()
155
+
156
+ for _ , calls := range cs .expected {
157
+ for _ , call := range calls {
158
+ if ! call .satisfied () {
159
+ return false
160
+ }
161
+ }
162
+ }
163
+
164
+ return true
165
+ }
0 commit comments