Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions mock/mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@ type Call struct {
// decoders.
RunFn func(Arguments)

// Holds a handler to a function that will be called before returning.
returnFn func(Arguments) Arguments

// PanicMsg holds msg to be used to mock panic on the function call
// if the PanicMsg is set to a non nil string the function call will panic
// irrespective of other settings
Expand Down Expand Up @@ -110,6 +113,7 @@ func (c *Call) Return(returnArguments ...interface{}) *Call {
defer c.unlock()

c.ReturnArguments = returnArguments
c.returnFn = nil

return c
}
Expand Down Expand Up @@ -187,6 +191,19 @@ func (c *Call) Run(fn func(args Arguments)) *Call {
return c
}

// ReturnFn sets a handler to be called before returning.
//
// Mock.On("MyMethod", arg1, arg2).ReturnFn(func(args Arguments) Arguments {
// return Arguments{args.Get(0) + args.Get(1)}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The example should be valid Go, did you mean Int rather than Get?

Suggested change
// return Arguments{args.Get(0) + args.Get(1)}
// return Arguments{args.Int(0) + args.Int(1)}

// })
func (c *Call) ReturnFn(fn func(args Arguments) Arguments) *Call {
c.lock()
defer c.unlock()
c.returnFn = fn
c.ReturnArguments = nil
return c
}

// Maybe allows the method call to be optional. Not calling an optional method
// will not cause an error while asserting expectations
func (c *Call) Maybe() *Call {
Expand Down Expand Up @@ -584,6 +601,14 @@ func (m *Mock) MethodCalled(methodName string, arguments ...interface{}) Argumen
returnArgs := call.ReturnArguments
m.mutex.Unlock()

m.mutex.Lock()
returnFn := call.returnFn
m.mutex.Unlock()

if returnFn != nil {
Copy link
Collaborator

@brackendawson brackendawson May 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because call.ReturnArguments is exported it should probably be checked by this condition rather than returnFn, eg:

	m = &myMock{}
	m.On("Do").ReturnFn(func(args mock.Arguments) mock.Arguments { return mock.Arguments{"two"} })
	m.ExpectedCalls[0].ReturnArguments = mock.Arguments{"one"}
	assert.Equal(t, "one", m.Do())

ie. Call.ReturnArguments should always override Call.returnFn

returnArgs = returnFn(arguments)
}

return returnArgs
}

Expand Down
144 changes: 142 additions & 2 deletions mock/mock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ type TestExampleImplementation struct {

func (i *TestExampleImplementation) TheExampleMethod(a, b, c int) (int, error) {
args := i.Called(a, b, c)
return args.Int(0), errors.New("Whoops")
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this looked wrong -- not sure why this would always return an error?

return args.Int(0), args.Error(1)
}

type options struct {
Expand Down Expand Up @@ -834,6 +834,146 @@ func Test_Mock_Return_Run_Out_Of_Order(t *testing.T) {
assert.NotNil(t, call.Run)
}

func Test_Mock_ReturnFn(t *testing.T) {

// make a test impl object
var mockedService = new(TestExampleImplementation)

t.Run("can dynamically set the return values", func(t *testing.T) {
counter := 0
mockedService.On("TheExampleMethod", Anything, Anything, Anything).
ReturnFn(func(args Arguments) Arguments {
counter++
a, b, c := args[0].(int), args[1].(int), args[2].(int)
assert.IsType(t, 1, a)
assert.IsType(t, 1, b)
assert.IsType(t, 1, c)
return Arguments{counter, nil}
}).
Twice()

answer, err := mockedService.TheExampleMethod(2, 4, 5)
assert.NoError(t, err)
assert.Equal(t, 1, answer)

answer, err = mockedService.TheExampleMethod(44, 4, 5)
assert.NoError(t, err)
assert.Equal(t, 2, answer)
})

t.Run("handles func(Args) Args style", func(t *testing.T) {
mockedService.On("TheExampleMethod", Anything, Anything, Anything).
ReturnFn(func(args Arguments) Arguments {
return []interface{}{args[0].(int) + 40, fmt.Errorf("hmm")}
}).
Twice()

answer, err := mockedService.TheExampleMethod(2, 4, 5)
assert.Error(t, err, "hmm")
assert.Equal(t, 42, answer)

answer, err = mockedService.TheExampleMethod(44, 4, 5)
assert.Error(t, err, "hmm")
assert.Equal(t, 84, answer)
})

t.Run("handles pointer input args", func(t *testing.T) {
mockedService.On("TheExampleMethod3", Anything).ReturnFn(func(arguments Arguments) Arguments {
et := arguments[0].(*ExampleType)
if et == nil {
return Arguments{errors.New("error")}
}
return Arguments{nil}
}).Twice()

err := mockedService.TheExampleMethod3(nil)
assert.Error(t, err)

err = mockedService.TheExampleMethod3(&ExampleType{})
assert.NoError(t, err)
})

t.Run("handles variadic input args", func(t *testing.T) {
mockedService.
On("TheExampleMethodMixedVariadic", Anything, Anything).
ReturnFn(func(args Arguments) Arguments {
a, b := args[0].(int), args[1].([]int)
var sum = a
for _, v := range b {
sum += v
}
return Arguments{fmt.Errorf("%v", sum)}
})

assert.Equal(t, "42", mockedService.TheExampleMethodMixedVariadic(40, 1, 1).Error())
assert.Equal(t, "40", mockedService.TheExampleMethodMixedVariadic(40).Error())
})

t.Run("allows all of Run and RunWithReturn and Return to be used", func(t *testing.T) {
mockedService.On("TheExampleMethod", Anything, Anything, Anything).
Run(func(args Arguments) {
a := args[0].(int)
assert.IsType(t, 1, a)
}).
ReturnFn(func(args Arguments) Arguments {
a := args[0].(int)
return Arguments{a + 40, fmt.Errorf("hmm")}
}).
Return(80, nil)

answer, err := mockedService.TheExampleMethod(2, 4, 5)
assert.Equal(t, 80, answer)
assert.NoError(t, err)
})
}

func Test_Mock_Return_RespectOrder(t *testing.T) {
tests := []struct {
name string
arrange func() *TestExampleImplementation
expected int
}{
{
name: "should take the last return value",
arrange: func() *TestExampleImplementation {
m := new(TestExampleImplementation)
m.On("TheExampleMethod", Anything, Anything, Anything).Return(1, nil).Return(2, nil)
return m
},
expected: 2,
},
{
name: "should take the last return value with returnFn",
arrange: func() *TestExampleImplementation {
m := new(TestExampleImplementation)
m.On("TheExampleMethod", Anything, Anything, Anything).Return(1, nil).ReturnFn(func(args Arguments) Arguments { return Arguments{2, nil} })
return m
},
expected: 2,
},
{
name: "should take the last return value with returnFn and return",
arrange: func() *TestExampleImplementation {
m := new(TestExampleImplementation)
m.On("TheExampleMethod", Anything, Anything, Anything).ReturnFn(func(args Arguments) Arguments {
return Arguments{1, nil}
}).Return(2, nil)
return m
},
expected: 2,
},
}
// run the tests
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
m := test.arrange()
actualResult, actualError := m.TheExampleMethod(0, 0, 0)
assert.NoError(t, actualError)
assert.Equal(t, test.expected, actualResult)
})
}
}

func Test_Mock_Return_Once(t *testing.T) {

// make a test impl object
Expand Down Expand Up @@ -1341,7 +1481,7 @@ func Test_Mock_Called_For_SetTime_Expectation(t *testing.T) {

var mockedService = new(TestExampleImplementation)

mockedService.On("TheExampleMethod", 1, 2, 3).Return(5, "6", true).Times(4)
mockedService.On("TheExampleMethod", 1, 2, 3).Return(5, nil).Times(4)

mockedService.TheExampleMethod(1, 2, 3)
mockedService.TheExampleMethod(1, 2, 3)
Expand Down