diff --git a/mock/mock.go b/mock/mock.go index 7f4d28d5e..32629896f 100644 --- a/mock/mock.go +++ b/mock/mock.go @@ -203,11 +203,29 @@ func (c *Call) Maybe() *Call { // On("MyMethod", 1).Return(nil). // On("MyOtherMethod", 'a', 'b', 'c').Return(errors.New("Some Error")) // +// See also [Call.OnFn] +// //go:noinline func (c *Call) On(methodName string, arguments ...interface{}) *Call { return c.Parent.On(methodName, arguments...) } +// OnFn chains a new expectation description onto the mcked interface using +// a function reference instead of a string method name. +// for example: +// +// mock. +// OnFn(mocked.MyMethod, 1).Return(nil). +// OnFn(mocked.MyOtherMethod, 'a', 'b', 'c').Return(errors.New("Some Error")) +// +// The `method` argument must be a function; otherwise, this call will panic. +// The function name is resolved using reflection and runtime information. +// +//go:noinline +func (c *Call) OnFn(method interface{}, args ...interface{}) *Call { + return c.Parent.On(runtimeMethodName(method), args...) +} + // Unset removes all mock handlers that satisfy the call instance arguments from being // called. Only supported on call instances with static input arguments. // @@ -366,6 +384,8 @@ func (m *Mock) fail(format string, args ...interface{}) { // being called. // // Mock.On("MyMethod", arg1, arg2) +// +// See also [Mock.OnFn] func (m *Mock) On(methodName string, arguments ...interface{}) *Call { for _, arg := range arguments { if v := reflect.ValueOf(arg); v.Kind() == reflect.Func { @@ -381,6 +401,18 @@ func (m *Mock) On(methodName string, arguments ...interface{}) *Call { return c } +// OnFn starts a description of an expectation of the specified method +// being called using a function reference instead of a string method name. +// +// Mock.OnFn(mocked.MyMethod, arg1, arg2) +// +// The `method` argument must be a function; otherwise, OnFn will panic. +// The function name is determined using reflection and runtime information, +// and then passed to [Mock.On](methodName, args...). +func (m *Mock) OnFn(method interface{}, args ...interface{}) *Call { + return m.On(runtimeMethodName(method), args...) +} + // /* // Recording and responding to activity // */ @@ -1303,6 +1335,20 @@ func funcName(f *runtime.Func) string { return splitted[len(splitted)-1] } +func runtimeMethodName(f interface{}) string { + t := reflect.TypeOf(f) + + if t.Kind() != reflect.Func { + panic("not a function") + } + + fname := runtime.FuncForPC(reflect.ValueOf(f).Pointer()).Name() + + parts := strings.Split(fname, ".") + + return strings.Split(parts[len(parts)-1], "-")[0] +} + func isFuncSame(f1, f2 *runtime.Func) bool { f1File, f1Loc := f1.FileLine(f1.Entry()) f2File, f2Loc := f2.FileLine(f2.Entry()) diff --git a/mock/mock_test.go b/mock/mock_test.go index 3dc9e0b1e..81b0069f7 100644 --- a/mock/mock_test.go +++ b/mock/mock_test.go @@ -558,6 +558,17 @@ func Test_Mock_On_WithFuncTypeArg(t *testing.T) { }) } +func Test_Mock_OnFn(t *testing.T) { + t.Parallel() + + // make a test impl object + var mockedService = new(TestExampleImplementation) + + c := mockedService.OnFn(mockedService.TheExampleMethod) + assert.Equal(t, []*Call{c}, mockedService.ExpectedCalls) + assert.Equal(t, "TheExampleMethod", c.Method) +} + func Test_Mock_Unset(t *testing.T) { t.Parallel()