diff --git a/mock/mock.go b/mock/mock.go index a13c37f3b..a1f941bf2 100644 --- a/mock/mock.go +++ b/mock/mock.go @@ -390,8 +390,7 @@ func (m *Mock) findExpectedCall(method string, arguments ...interface{}) (int, * for i, call := range m.ExpectedCalls { if call.Method == method { - _, diffCount := call.Arguments.Diff(arguments) - if diffCount == 0 { + if call.Arguments.matchCount(arguments) == 0 { expectedCall = call if call.Repeatability > -1 { return i, call @@ -405,7 +404,6 @@ func (m *Mock) findExpectedCall(method string, arguments ...interface{}) (int, * type matchCandidate struct { call *Call - mismatch string diffCount int } @@ -430,16 +428,14 @@ func (c matchCandidate) isBetterMatchThan(other matchCandidate) bool { return false } -func (m *Mock) findClosestCall(method string, arguments ...interface{}) (*Call, string) { +func (m *Mock) findClosestCall(method string, arguments ...interface{}) *Call { var bestMatch matchCandidate for _, call := range m.expectedCalls() { if call.Method == method { - - errInfo, tempDiffCount := call.Arguments.Diff(arguments) + tempDiffCount := call.Arguments.matchCount(arguments) tempCandidate := matchCandidate{ call: call, - mismatch: errInfo, diffCount: tempDiffCount, } if tempCandidate.isBetterMatchThan(bestMatch) { @@ -448,7 +444,7 @@ func (m *Mock) findClosestCall(method string, arguments ...interface{}) (*Call, } } - return bestMatch.call, bestMatch.mismatch + return bestMatch.call } func callString(method string, arguments Arguments, includeArgumentValues bool) string { @@ -512,10 +508,13 @@ func (m *Mock) MethodCalled(methodName string, arguments ...interface{}) Argumen // a) this is a totally unexpected call to this method, // b) the arguments are not what was expected, or // c) the developer has forgotten to add an accompanying On...Return pair. - closestCall, mismatch := m.findClosestCall(methodName, arguments...) + closestCall := m.findClosestCall(methodName, arguments...) m.mutex.Unlock() if closestCall != nil { + // Format the diff outside the mutex to avoid deadlocks when + // arguments implement Stringer and call back into MethodCalled. + mismatch, _ := closestCall.Arguments.Diff(arguments) m.fail("\n\nmock: Unexpected Method Call\n-----------------------------\n\n%s\n\nThe closest call I have is: \n\n%s\n\n%s\nDiff: %s\nat: %s\n", callString(methodName, arguments, true), callString(methodName, closestCall.Arguments, true), @@ -953,6 +952,66 @@ func (args Arguments) Is(objects ...interface{}) bool { return true } +// matchCount returns the number of argument differences without formatting +// output strings. This is safe to call while holding a mutex because it +// does not invoke user-defined methods like String() or GoString() that +// could call back into MethodCalled and cause a deadlock. +func (args Arguments) matchCount(objects []interface{}) int { + maxArgCount := len(args) + if len(objects) > maxArgCount { + maxArgCount = len(objects) + } + + var differences int + for i := 0; i < maxArgCount; i++ { + if len(objects) <= i || len(args) <= i { + differences++ + continue + } + actual := objects[i] + expected := args[i] + + if matcher, ok := expected.(argumentMatcher); ok { + func() { + defer func() { + if recover() != nil { + differences++ + } + }() + if !matcher.Matches(actual) { + differences++ + } + }() + } else { + switch expected := expected.(type) { + case anythingOfTypeArgument: + if reflect.TypeOf(actual).Name() != string(expected) && reflect.TypeOf(actual).String() != string(expected) { + differences++ + } + case *IsTypeArgument: + if reflect.TypeOf(actual) != expected.t { + differences++ + } + case *FunctionalOptionsArgument: + var name string + if len(expected.values) > 0 { + name = "[]" + reflect.TypeOf(expected.values[0]).String() + } + if name != reflect.TypeOf(actual).String() && len(expected.values) != 0 { + differences++ + } else if ef, af := assertOpts(expected.values, actual); ef != "" || af != "" { + differences++ + } + default: + if !assert.ObjectsAreEqual(expected, Anything) && !assert.ObjectsAreEqual(actual, Anything) && !assert.ObjectsAreEqual(actual, expected) { + differences++ + } + } + } + } + return differences +} + // Diff gets a string describing the differences between the arguments // and the specified objects. // diff --git a/mock/mock_test.go b/mock/mock_test.go index 3dc9e0b1e..b1144e276 100644 --- a/mock/mock_test.go +++ b/mock/mock_test.go @@ -2462,6 +2462,35 @@ func TestIssue1785ArgumentWithMutatingStringer(t *testing.T) { m.AssertExpectations(t) } +// TestIssue1719StringerDeadlock verifies that MethodCalled does not deadlock +// when an argument's String() method calls back into MethodCalled. +// See https://github.com/stretchr/testify/issues/1719 +func TestIssue1719StringerDeadlock(t *testing.T) { + done := make(chan struct{}) + + go func() { + defer close(done) + + m := &Mock{} + m.On("String").Return("") + m.On("DoAThing", Anything).Return() + + // When DoAThing is called with the mock itself as an argument, + // Diff used to format the argument with %v, triggering String(), + // which calls MethodCalled("String") — deadlock because the mutex + // is already held by the outer MethodCalled("DoAThing"). + m.MethodCalled("DoAThing", m) + m.MethodCalled("String") + }() + + select { + case <-done: + // Success — no deadlock + case <-time.After(5 * time.Second): + t.Fatal("MethodCalled deadlocked when argument's String() calls MethodCalled") + } +} + func TestIssue1227AssertExpectationsForObjectsWithMock(t *testing.T) { mockT := &MockTestingT{} AssertExpectationsForObjects(mockT, Mock{})