diff --git a/common/testing/await/report.go b/common/testing/await/report.go index 23d6f40b70..f0720b9323 100644 --- a/common/testing/await/report.go +++ b/common/testing/await/report.go @@ -4,6 +4,7 @@ import ( "fmt" "strings" "testing" + "text/tabwriter" "time" ) @@ -21,10 +22,16 @@ type attemptFailure struct { } type timeoutReport struct { - effectiveTimeout time.Duration - attempts int - attemptTimeouts int - failures []attemptFailure + effectiveTimeout time.Duration + configuredTimeout time.Duration + attemptTimeout time.Duration + testExtensionReport string + deadlineCause string + attempts int + attemptTimeouts int + attemptDurationSum time.Duration + attemptDurationMax time.Duration + failures []attemptFailure } func (r *timeoutReport) nextPoll() { @@ -41,18 +48,98 @@ func (r *timeoutReport) recordAttemptTimeout() { r.attemptTimeouts++ } +func (r *timeoutReport) recordAttemptDuration(d time.Duration) { + r.attemptDurationSum += d + r.attemptDurationMax = max(r.attemptDurationMax, d) +} + func (r timeoutReport) reportAttemptErrors(tb testing.TB) { reportAttemptErrors(tb, r.failures) } func (r timeoutReport) reportTimeout(tb testing.TB, funcName, timeoutMsg string) { r.reportAttemptErrors(tb) - message := fmt.Sprintf("condition not satisfied after %v", r.effectiveTimeout) + message := fmt.Sprintf("condition not satisfied after %v", reportDuration(r.effectiveTimeout)) if timeoutMsg != "" { - message = fmt.Sprintf("%s (not satisfied after %v)", timeoutMsg, r.effectiveTimeout) + message = fmt.Sprintf("%s (not satisfied after %v)", timeoutMsg, reportDuration(r.effectiveTimeout)) + } + var details strings.Builder + detailWriter := tabwriter.NewWriter(&details, 0, 0, 1, ' ', 0) + var detailErr error + writeDetail := func(label, value string) { + if detailErr != nil { + return + } + _, detailErr = fmt.Fprintf(detailWriter, " %s\t= %s\n", label, value) + } + + hasAttemptFailures := len(r.failures) > 0 || r.attemptTimeouts > 0 + shortenedTimeout := 
r.configuredTimeout-r.effectiveTimeout > time.Millisecond + if r.configuredTimeout > 0 && (!hasAttemptFailures || shortenedTimeout) { + value := reportDuration(r.effectiveTimeout) + if shortenedTimeout { + value += fmt.Sprintf(" (configured %v", reportDuration(r.configuredTimeout)) + if r.deadlineCause != "" { + value += "; limited by " + r.deadlineCause + } + value += ")" + } + writeDetail("await timeout", value) + } + writeDetail("attempts", fmt.Sprintf("%d", r.attempts)) + if r.attemptTimeouts > 0 { + writeDetail("attempt timeout", fmt.Sprintf("%d (configured as %v)", r.attemptTimeouts, reportDuration(r.attemptTimeout))) + } + if r.attempts > 0 { + writeDetail( + "attempt duration", + fmt.Sprintf( + "avg %v, max %v", + reportDuration(r.attemptDurationSum/time.Duration(r.attempts)), + reportDuration(r.attemptDurationMax), + ), + ) + } + if r.attemptTimeouts == 0 && r.deadlineCause != "" { + writeDetail("last failure", r.deadlineCause) + } + if r.testExtensionReport != "" { + if detailErr == nil { + _, detailErr = fmt.Fprintln(detailWriter, indentDetail(r.testExtensionReport)) + } + } + if detailErr != nil { + tb.Fatalf("%s: failed to render timeout report: %v", funcName, detailErr) + return + } + if err := detailWriter.Flush(); err != nil { + tb.Fatalf("%s: failed to render timeout report: %v", funcName, err) + return + } + tb.Fatalf("%s: %s\ndetails:\n%s", funcName, message, strings.TrimSuffix(details.String(), "\n")) +} + +func reportDuration(d time.Duration) string { + if d > -time.Millisecond && d < time.Millisecond { + rounded := d.Round(time.Microsecond) + if rounded == 0 { + return "0µs" + } + return rounded.String() + } + return d.Round(time.Millisecond).String() +} + +func indentDetail(s string) string { + var b strings.Builder + for line := range strings.SplitSeq(s, "\n") { + if b.Len() > 0 { + b.WriteByte('\n') + } + b.WriteString(" ") + b.WriteString(line) } - tb.Fatalf("%s: %s\ndetails:\n attempts = %d\n attempt timeouts = %d", - funcName, message, 
r.attempts, r.attemptTimeouts) + return b.String() } func reportAttemptErrors(tb testing.TB, failures []attemptFailure) { diff --git a/common/testing/await/report_test.go b/common/testing/await/report_test.go deleted file mode 100644 index eecf706b00..0000000000 --- a/common/testing/await/report_test.go +++ /dev/null @@ -1,71 +0,0 @@ -package await - -import ( - "fmt" - "strings" - "sync" - "testing" - "time" - - "github.com/stretchr/testify/require" -) - -func TestReportTimeout(t *testing.T) { - t.Run("without message", func(t *testing.T) { - tb := newReportRecordingTB() - - timeoutReport{ - effectiveTimeout: time.Second, - attempts: 3, - attemptTimeouts: 2, - }.reportTimeout(tb, "Require", "") - - require.Equal(t, strings.Join([]string{ - "Require: condition not satisfied after 1s", - "details:", - " attempts = 3", - " attempt timeouts = 2", - }, "\n"), tb.fatals()) - }) - - t.Run("with message", func(t *testing.T) { - tb := newReportRecordingTB() - - timeoutReport{ - effectiveTimeout: 2 * time.Second, - attempts: 4, - attemptTimeouts: 1, - }.reportTimeout(tb, "Require", "workflow wf-123 not ready") - - require.Equal(t, strings.Join([]string{ - "Require: workflow wf-123 not ready (not satisfied after 2s)", - "details:", - " attempts = 4", - " attempt timeouts = 1", - }, "\n"), tb.fatals()) - }) -} - -type reportRecordingTB struct { - testing.TB - mu sync.Mutex - fatalMessages []string -} - -func newReportRecordingTB() *reportRecordingTB { - return &reportRecordingTB{} -} - -func (r *reportRecordingTB) Helper() {} - -func (r *reportRecordingTB) Fatalf(format string, args ...any) { - r.mu.Lock() - defer r.mu.Unlock() - r.fatalMessages = append(r.fatalMessages, fmt.Sprintf(format, args...)) -} - -func (r *reportRecordingTB) fatals() string { - r.mu.Lock() - defer r.mu.Unlock() - return strings.Join(r.fatalMessages, "\n") -} diff --git a/common/testing/await/require_ctx.go b/common/testing/await/require_ctx.go index 3b143b16b6..36c29b7f22 100644 --- 
a/common/testing/await/require_ctx.go +++ b/common/testing/await/require_ctx.go @@ -100,26 +100,34 @@ func run( parentCtx = extension.Context() } deadline := start.Add(cfg.totalTimeout) + deadlineCause := "" if !extension.Deadline.IsZero() && extension.Deadline.Before(deadline) { deadline = extension.Deadline + deadlineCause = testcontextDeadlineCause(extension.Limit, parentIsTestContext) } // Cap at the parent context's deadline if it's earlier than our timeout. if parentDeadline, hasDeadline := parentCtx.Deadline(); hasDeadline && parentDeadline.Before(deadline) { deadline = parentDeadline + deadlineCause = testcontextDeadlineCause(extension.Limit, parentIsTestContext) } effectiveTimeout := max(0, time.Until(deadline)) awaitCtx, awaitCancel := context.WithDeadline(parentCtx, deadline) defer awaitCancel() - report := timeoutReport{effectiveTimeout: effectiveTimeout} + report := timeoutReport{ + effectiveTimeout: effectiveTimeout, + configuredTimeout: cfg.totalTimeout, + attemptTimeout: cfg.attemptTimeout, + testExtensionReport: extension.Report, + deadlineCause: deadlineCause, + } for { // Parent context was canceled while we were sleeping (not our deadline). if err := awaitCtx.Err(); err != nil && !deadlineReached(deadline) { - report.reportAttemptErrors(tb) - tb.Fatalf("%s: context canceled before condition was satisfied: %v", funcName, err) + failContextCanceled(tb, report, funcName, err) return } @@ -131,7 +139,9 @@ func run( t := &T{tb: tb, ctx: attemptCtx} // Run attempt. + attemptStart := time.Now() res := runAttempt(t, condition, attemptCancel, funcName, cancellable) + report.recordAttemptDuration(time.Since(attemptStart)) attemptCancel() if res.panicVal != nil { panic(res.panicVal) // propagate to caller @@ -165,13 +175,15 @@ func run( // Parent context was canceled during the attempt (not our deadline). 
if err := awaitCtx.Err(); err != nil && !deadlineReached(deadline) { - report.reportAttemptErrors(tb) - tb.Fatalf("%s: context canceled before condition was satisfied: %v", funcName, err) + failContextCanceled(tb, report, funcName, err) return } // Our deadline expired. if deadlineReached(deadline) { + if deadlineCause != "" { + extension.SuppressCleanupReport() + } report.reportTimeout(tb, funcName, cfg.timeoutMsg) return } @@ -186,6 +198,22 @@ func run( } } +func failContextCanceled(tb testing.TB, report timeoutReport, funcName string, err error) { + tb.Helper() + report.reportAttemptErrors(tb) + tb.Fatalf("%v", fmt.Errorf("%s: context canceled before condition was satisfied: %w", funcName, err)) +} + +func testcontextDeadlineCause(limit string, parentIsTestContext bool) string { + if limit != "" { + return limit + } + if parentIsTestContext { + return "test context deadline" + } + return "parent context deadline" +} + // attemptResult describes how an attempt terminated. Exactly one of the // following fields is set: // - panicVal != nil: condition panicked with a non-attemptFailed value; diff --git a/common/testing/await/require_ctx_test.go b/common/testing/await/require_ctx_test.go index 1872ea2857..9afcda36ec 100644 --- a/common/testing/await/require_ctx_test.go +++ b/common/testing/await/require_ctx_test.go @@ -8,6 +8,7 @@ import ( "sync" "sync/atomic" "testing" + "testing/synctest" "time" "github.com/stretchr/testify/require" @@ -153,110 +154,201 @@ func TestRequire_PollIntervalStartsAfterAttemptFinishes(t *testing.T) { } func TestRequire_FailureScenarios(t *testing.T) { - t.Run("reports timeout", func(t *testing.T) { + t.Run("retries failed attempts until await timeout", func(t *testing.T) { t.Parallel() - ctx := testcontext.For(t) - tb := newRecordingTB() - tb.run(func() { - await.Require(ctx, tb, func(t *await.T) { - t.Error("not ready") - }, time.Second, 100*time.Millisecond) + synctest.Test(t, func(t *testing.T) { + ctx := testcontext.For(t) + var 
attempts atomic.Int32 + + tb := newRecordingTB() + tb.run(func() { + await.Require(ctx, tb, func(t *await.T) { + n := attempts.Add(1) + t.Errorf("attempt %d failed", n) + }, time.Second, 100*time.Millisecond) + }) + + require.True(t, tb.Failed()) + require.Equal(t, strings.Join([]string{ + "Require: condition not satisfied after 1s", + "details:", + " attempts = 11", + " attempt duration = avg 0µs, max 0µs", + }, "\n"), tb.fatals()) + require.Equal(t, strings.Join([]string{ + "attempt errors:", + "", + " --- attempt 1 ---", + " attempt 1 failed", + " ... 7 attempts omitted ...", + "", + " --- attempt 9 ---", + " attempt 9 failed", + "", + " --- attempt 10 ---", + " attempt 10 failed", + "", + " --- attempt 11 ---", + " attempt 11 failed", + }, "\n"), tb.errors()) }) - require.True(t, tb.Failed()) - require.Contains(t, tb.fatals(), "not satisfied after") }) - t.Run("cancels attempt context on timeout", func(t *testing.T) { + t.Run("cancels running attempt at await timeout", func(t *testing.T) { t.Parallel() - ctx := testcontext.For(t) - tb := newRecordingTB() - tb.run(func() { - await.Require(ctx, tb, func(t *await.T) { - <-t.Context().Done() - if t.Context().Err() != context.DeadlineExceeded { - t.Errorf("context error = %v", t.Context().Err()) - } - }, 2*time.Second, time.Second) + synctest.Test(t, func(t *testing.T) { + ctx := testcontext.For(t) + + tb := newRecordingTB() + tb.run(func() { + await.Require(ctx, tb, func(t *await.T) { + <-t.Context().Done() + if t.Context().Err() != context.DeadlineExceeded { + t.Errorf("context error = %v", t.Context().Err()) + } + }, 2*time.Second, time.Second) + }) + + require.True(t, tb.Failed()) + require.Equal(t, strings.Join([]string{ + "Require: condition not satisfied after 2s", + "details:", + " await timeout = 2s", + " attempts = 1", + " attempt duration = avg 2s, max 2s", + }, "\n"), tb.fatals()) }) - require.True(t, tb.Failed()) - require.Contains(t, tb.fatals(), "not satisfied after") }) - t.Run("retries after attempt 
timeout until await timeout", func(t *testing.T) { - attemptTimeoutEnv := 50 * time.Millisecond - attemptTimeout := attemptTimeoutEnv * debug.TimeoutMultiplier - pollInterval := 100 * time.Millisecond - t.Setenv("TEMPORAL_AWAIT_ATTEMPT_TIMEOUT", attemptTimeoutEnv.String()) + t.Run("retries after attempt deadline expires", func(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + attemptTimeoutEnv := 50 * time.Millisecond + attemptTimeout := attemptTimeoutEnv * debug.TimeoutMultiplier + pollInterval := 100 * time.Millisecond + t.Setenv("TEMPORAL_AWAIT_ATTEMPT_TIMEOUT", attemptTimeoutEnv.String()) - ctx := testcontext.For(t) - var attempts atomic.Int32 - var firstAttemptRemaining time.Duration + ctx := testcontext.For(t) + var attempts atomic.Int32 + var firstAttemptRemaining time.Duration + + tb := newRecordingTB() + tb.run(func() { + await.Require(ctx, tb, func(t *await.T) { + if attempts.Add(1) == 1 { + deadline, _ := t.Context().Deadline() + firstAttemptRemaining = time.Until(deadline) + } + <-t.Context().Done() + }, attemptTimeout+2*pollInterval, pollInterval) + }) - tb := newRecordingTB() - tb.run(func() { - await.Require(ctx, tb, func(t *await.T) { - if attempts.Add(1) == 1 { - deadline, _ := t.Context().Deadline() - firstAttemptRemaining = time.Until(deadline) - } - <-t.Context().Done() - }, attemptTimeout+2*pollInterval, pollInterval) + require.True(t, tb.Failed()) + require.Equal(t, strings.Join([]string{ + "Require: condition not satisfied after 250ms", + "details:", + " attempts = 3", + " attempt timeout = 2 (configured as 50ms)", + " attempt duration = avg 33ms, max 50ms", + }, "\n"), tb.fatals()) + require.Equal(t, attemptTimeout, firstAttemptRemaining) + require.Equal(t, int32(3), attempts.Load()) }) - - require.True(t, tb.Failed()) - require.Contains(t, tb.fatals(), "not satisfied after") - require.Positive(t, firstAttemptRemaining) - require.LessOrEqual(t, firstAttemptRemaining, attemptTimeout) - require.Greater(t, attempts.Load(), int32(1)) }) 
- t.Run("does not poll again after attempt consumes timeout", func(t *testing.T) { + t.Run("does not start another attempt after await timeout", func(t *testing.T) { t.Parallel() - ctx := testcontext.For(t) - var attempts atomic.Int32 + synctest.Test(t, func(t *testing.T) { + ctx := testcontext.For(t) + var attempts atomic.Int32 - tb := newRecordingTB() - tb.run(func() { - await.Require(ctx, tb, func(t *await.T) { - attempts.Add(1) - <-t.Context().Done() // block until timeout - }, time.Second, 100*time.Millisecond) + tb := newRecordingTB() + tb.run(func() { + await.Require(ctx, tb, func(t *await.T) { + attempts.Add(1) + <-t.Context().Done() // block until timeout + }, time.Second, 100*time.Millisecond) + }) + + require.True(t, tb.Failed()) + require.Equal(t, strings.Join([]string{ + "Require: condition not satisfied after 1s", + "details:", + " await timeout = 1s", + " attempts = 1", + " attempt duration = avg 1s, max 1s", + }, "\n"), tb.fatals()) + require.Equal(t, int32(1), attempts.Load()) }) - require.True(t, tb.Failed()) - require.Contains(t, tb.fatals(), "not satisfied after") - require.Equal(t, int32(1), attempts.Load()) }) - t.Run("caps attempt context with parent deadline", func(t *testing.T) { + t.Run("reports parent context deadline as await limit", func(t *testing.T) { t.Parallel() - parentCtx, cancel := context.WithTimeout(testcontext.For(t), time.Second) - defer cancel() + synctest.Test(t, func(t *testing.T) { + parentCtx, cancel := context.WithTimeout(testcontext.For(t), time.Second) + defer cancel() + + tb := newRecordingTB() + tb.run(func() { + await.Require(parentCtx, tb, func(t *await.T) { + deadline, ok := t.Context().Deadline() + if !ok { + t.Error("missing deadline") + } + if time.Until(deadline) > time.Second { + t.Errorf("deadline = %v", deadline) + } + <-t.Context().Done() + if t.Context().Err() != context.DeadlineExceeded { + t.Errorf("context error = %v", t.Context().Err()) + } + }, 2*time.Second, time.Second) + }) - tb := 
newRecordingTB() - tb.run(func() { - await.Require(parentCtx, tb, func(t *await.T) { - deadline, ok := t.Context().Deadline() - if !ok { - t.Error("missing deadline") - } - if time.Until(deadline) > time.Second { - t.Errorf("deadline = %v", deadline) - } - <-t.Context().Done() - if t.Context().Err() != context.DeadlineExceeded { - t.Errorf("context error = %v", t.Context().Err()) - } - }, 2*time.Second, time.Second) + require.True(t, tb.Failed()) + require.Equal(t, strings.Join([]string{ + "Require: condition not satisfied after 1s", + "details:", + " await timeout = 1s (configured 2s; limited by parent context deadline)", + " attempts = 1", + " attempt duration = avg 1s, max 1s", + " last failure = parent context deadline", + }, "\n"), tb.fatals()) + }) + }) + + t.Run("reports test context extension cap as await limit", func(t *testing.T) { + t.Setenv("TEMPORAL_AWAIT_ATTEMPT_TIMEOUT", "10s") + + synctest.Test(t, func(t *testing.T) { + tb := newRecordingTB() + + ctx := testcontext.For(tb, testcontext.WithTimeout(5*time.Second)) + + tb.run(func() { + await.Require(ctx, tb, func(t *await.T) { + <-t.Context().Done() + }, 3*time.Minute, time.Second) + }) + + require.True(t, tb.Failed()) + require.Equal(t, strings.Join([]string{ + "Require: condition not satisfied after 2m0s", + "details:", + " await timeout = 2m0s (configured 3m0s; limited by test context extension cap)", + " attempts = 11", + " attempt timeout = 10 (configured as 10s)", + " attempt duration = avg 10s, max 10s", + " ctx extensions = 1 (+1m55s total)", + " 1. 
+1m55s after 0µs", + }, "\n"), tb.fatals()) }) - require.True(t, tb.Failed()) - require.Contains(t, tb.fatals(), "not satisfied after") }) - t.Run("parent context cancellation stops polling", func(t *testing.T) { + t.Run("stops after parent context cancellation", func(t *testing.T) { t.Parallel() parentCtx, cancel := context.WithCancel(testcontext.For(t)) @@ -271,76 +363,37 @@ func TestRequire_FailureScenarios(t *testing.T) { cancel() }, time.Second, 100*time.Millisecond) }) + require.True(t, tb.Failed()) - require.Contains(t, tb.fatals(), "context canceled before condition was satisfied") + require.Equal(t, "Require: context canceled before condition was satisfied: context canceled", tb.fatals()) require.Equal(t, int32(1), attempts.Load(), "expected cancellation to stop polling") }) - t.Run("reports all attempt errors on timeout", func(t *testing.T) { + t.Run("uses Requiref message on await timeout", func(t *testing.T) { t.Parallel() - ctx := testcontext.For(t) - var attempts atomic.Int32 - tb := newRecordingTB() - tb.run(func() { - await.Require(ctx, tb, func(t *await.T) { - if attempts.Add(1) == 1 { - t.Error("first attempt error") - return - } - <-t.Context().Done() - t.Error("last attempt error") - }, time.Second, 100*time.Millisecond) - }) - require.True(t, tb.Failed()) - require.Contains(t, tb.fatals(), "not satisfied after") - require.Equal(t, "attempt errors:\n\n --- attempt 1 ---\n first attempt error\n\n --- attempt 2 ---\n last attempt error", tb.errors()) - require.Equal(t, int32(2), attempts.Load()) - }) - - t.Run("truncates middle attempts when many fail", func(t *testing.T) { - t.Parallel() - - ctx := testcontext.For(t) - var attempts atomic.Int32 - tb := newRecordingTB() - tb.run(func() { - await.Require(ctx, tb, func(t *await.T) { - n := attempts.Add(1) - t.Errorf("attempt %d failed", n) - }, 400*time.Millisecond, 50*time.Millisecond) - }) - require.True(t, tb.Failed()) - require.Contains(t, tb.fatals(), "not satisfied after") + synctest.Test(t, 
func(t *testing.T) { + ctx := testcontext.For(t) - n := attempts.Load() - require.Greater(t, n, int32(4), "need >4 attempts to exercise truncation") - - errs := tb.errors() - require.Contains(t, errs, "attempt errors:\n\n --- attempt 1 ---\n attempt 1 failed\n") - require.Contains(t, errs, fmt.Sprintf("... %d attempts omitted ...", n-4)) - // Last three attempts present in order. - for i := n - 2; i <= n; i++ { - require.Contains(t, errs, fmt.Sprintf("--- attempt %d ---\n attempt %d failed", i, i)) - } - }) - - t.Run("Requiref includes message on timeout", func(t *testing.T) { - t.Parallel() + tb := newRecordingTB() + tb.run(func() { + await.Requiref(ctx, tb, func(t *await.T) { + t.Error("not ready") + }, time.Second, 100*time.Millisecond, "workflow %s not ready", "wf-123") + }) - ctx := testcontext.For(t) - tb := newRecordingTB() - tb.run(func() { - await.Requiref(ctx, tb, func(t *await.T) { - t.Error("not ready") - }, time.Second, 100*time.Millisecond, "workflow %s not ready", "wf-123") + require.True(t, tb.Failed()) + require.Equal(t, strings.Join([]string{ + "Requiref: workflow wf-123 not ready (not satisfied after 1s)", + "details:", + " attempts = 11", + " attempt duration = avg 0µs, max 0µs", + }, "\n"), tb.fatals()) }) - require.True(t, tb.Failed()) - require.Contains(t, tb.fatals(), "workflow wf-123 not ready") }) - t.Run("panic propagates", func(t *testing.T) { + t.Run("propagates panic from attempt", func(t *testing.T) { t.Parallel() require.PanicsWithValue(t, "unexpected nil pointer", func() { @@ -350,37 +403,52 @@ func TestRequire_FailureScenarios(t *testing.T) { }) }) - t.Run("reports real TB misuse", func(t *testing.T) { + t.Run("detects real TB misuse", func(t *testing.T) { t.Parallel() for _, tc := range []struct { - name string - misuse func(*recordingTB) + name string + misuse func(*recordingTB) + expected string }{ - {"Fatal stops real TB", func(tb *recordingTB) { tb.Fatal("wrong t used") }}, - {"Errorf marks real TB failed", func(tb 
*recordingTB) { tb.Errorf("assert-style misuse") }}, + { + name: "Fatal stops real TB", + misuse: func(tb *recordingTB) { tb.Fatal("wrong t used") }, + expected: strings.Join([]string{ + "wrong t used", + "Require: the test was marked failed directly — use the *await.T passed to the callback, not s.T() or suite assertion methods", + }, "\n"), + }, + { + name: "Errorf marks real TB failed", + misuse: func(tb *recordingTB) { tb.Errorf("assert-style misuse") }, + expected: "Require: the test was marked failed directly — use the *await.T passed to the callback, not s.T() or suite assertion methods", + }, } { t.Run(tc.name, func(t *testing.T) { t.Parallel() ctx := testcontext.For(t) + tb := newRecordingTB() tb.run(func() { await.Require(ctx, tb, func(_ *await.T) { tc.misuse(tb) }, time.Second, 100*time.Millisecond) }) + require.True(t, tb.Failed()) - require.Contains(t, tb.fatals(), "use the *await.T") + require.Equal(t, tc.expected, tb.fatals()) }) } }) - t.Run("does not poll after prior failure", func(t *testing.T) { + t.Run("skips await after prior failure", func(t *testing.T) { t.Parallel() ctx := testcontext.For(t) conditionCalled := false + tb := newRecordingTB() tb.run(func() { tb.Errorf("previous failure") @@ -388,6 +456,7 @@ func TestRequire_FailureScenarios(t *testing.T) { conditionCalled = true }, time.Second, 100*time.Millisecond) }) + require.True(t, tb.Failed()) require.Empty(t, tb.fatals()) require.False(t, conditionCalled, "condition should not run when test already failed") @@ -438,27 +507,35 @@ func TestRequire_DeadlockDetected(t *testing.T) { elapsed := time.Since(start) require.True(t, tb.Failed()) require.Contains(t, tb.logs(), "soft deadlock") - require.Contains(t, tb.fatals(), "still running") - require.Contains(t, tb.fatals(), "does it honor t.Context()") + require.Equal(t, + "Require: condition still running 100ms past context cancellation — does it honor t.Context()? 
(1 attempts)", + tb.fatals(), + ) require.Less(t, elapsed, awaitTimeout, "should fail at hard deadlock, not wait the full await timeout (elapsed=%v)", elapsed) } func TestRequire_WaitsForInFlightAttemptOnTimeout(t *testing.T) { - t.Parallel() - - var finished atomic.Bool - ctx := testcontext.For(t) - tb := newRecordingTB() - tb.run(func() { - await.Require(ctx, tb, func(t *await.T) { - <-t.Context().Done() - finished.Store(true) - }, time.Second, time.Second) + synctest.Test(t, func(t *testing.T) { + var finished atomic.Bool + ctx := testcontext.For(t) + tb := newRecordingTB() + tb.run(func() { + await.Require(ctx, tb, func(t *await.T) { + <-t.Context().Done() + finished.Store(true) + }, time.Second, time.Second) + }) + require.True(t, tb.Failed()) + require.Equal(t, strings.Join([]string{ + "Require: condition not satisfied after 1s", + "details:", + " await timeout = 1s", + " attempts = 1", + " attempt duration = avg 1s, max 1s", + }, "\n"), tb.fatals()) + require.True(t, finished.Load(), "Require returned before the running attempt exited") }) - require.True(t, tb.Failed()) - require.Contains(t, tb.fatals(), "not satisfied after") - require.True(t, finished.Load(), "Require returned before the running attempt exited") } // recordingTB is a minimal testing.TB implementation for testing failure scenarios. 
diff --git a/common/testing/testcontext/context.go b/common/testing/testcontext/context.go index 92560bbed1..4395fd6b83 100644 --- a/common/testing/testcontext/context.go +++ b/common/testing/testcontext/context.go @@ -2,6 +2,8 @@ package testcontext import ( "context" + "errors" + "fmt" "os" "sync" "testing" @@ -15,6 +17,9 @@ const ( defaultTimeout = 90 * time.Second defaultMaxTestTimeout = 2 * time.Minute testNameMetadataKey = "temporal-test-name" + + deadlineLimitTestContextCap = "test context extension cap" + deadlineLimitGoTestTimeout = "go test timeout" ) type contextStore struct { @@ -38,13 +43,21 @@ type contextDecorator struct { decorate func(context.Context) context.Context } +type extensionGrant struct { + duration time.Duration + elapsed time.Duration +} + // Extension describes the effect of an EnsureRemaining call. type Extension struct { Deadline time.Time Granted time.Duration + Report string + Limit string currentContext context.Context contexts []context.Context + state *contextState } // AppliesTo reports whether ctx is the test-scoped context observed by @@ -63,6 +76,13 @@ func (e Extension) Context() context.Context { return e.currentContext } +// SuppressCleanupReport marks the timeout as already reported. +func (e Extension) SuppressCleanupReport() { + if e.state != nil { + e.state.markTimeoutReported() + } +} + // DefaultTimeout returns the effective default timeout for test-scoped contexts. func DefaultTimeout() time.Duration { return effectiveTimeout(0) @@ -116,14 +136,19 @@ func AttachDecorator[K comparable](tb testing.TB, key K, decorator func(context. 
} type contextState struct { - mu sync.Mutex - ctx context.Context - contexts []context.Context - cancels []context.CancelFunc - testStart time.Time - originalTimeout time.Duration - decorators map[any]struct{} - orderedDecorators []contextDecorator + mu sync.Mutex + ctx context.Context + contexts []context.Context + cancels []context.CancelFunc + testStart time.Time + originalTimeout time.Duration + decorators map[any]struct{} + orderedDecorators []contextDecorator + extensionGrants []extensionGrant + extensionDenied int + extensionRequestedTotal time.Duration + deadlineLimit string + timeoutReported bool } func getContextState(tb testing.TB, timeout time.Duration) *contextState { @@ -141,7 +166,7 @@ func getContextState(tb testing.TB, timeout time.Duration) *contextState { originalTimeout: timeout, decorators: make(map[any]struct{}), } - st.setDeadline(tb, st.testStart.Add(timeout)) + st.setDeadline(tb, st.testStart.Add(timeout), "") testContexts.byTest[tb] = st tb.Cleanup(func() { @@ -150,8 +175,8 @@ func getContextState(tb testing.TB, timeout time.Duration) *contextState { testContexts.Lock() delete(testContexts.byTest, tb) testContexts.Unlock() - if err == context.DeadlineExceeded { - tb.Errorf("test exceeded timeout of %v", st.currentTimeout()) + if errors.Is(err, context.DeadlineExceeded) && st.markTimeoutReported() { + tb.Errorf("%v", st.timeoutExceededError(err)) } }) return st @@ -195,19 +220,31 @@ func (s *contextState) ensureRemainingUntil(tb testing.TB, target time.Time) Ext return s.extension(currentDeadline, 0) } + s.extensionRequestedTotal += requested + limit := "" capDeadline := s.testStart.Add(max(maxTestTimeout(), s.originalTimeout)) if capDeadline.Before(target) { target = capDeadline + limit = deadlineLimitTestContextCap } if goTestDeadline, ok := tb.Context().Deadline(); ok && goTestDeadline.Before(target) { target = goTestDeadline + limit = deadlineLimitGoTestTimeout } granted := target.Sub(currentDeadline) if granted <= 0 { + if limit != "" 
{ + s.deadlineLimit = limit + } + s.extensionDenied++ return s.extension(currentDeadline, 0) } - s.setDeadline(tb, target) + s.extensionGrants = append(s.extensionGrants, extensionGrant{ + duration: granted, + elapsed: time.Since(s.testStart), + }) + s.setDeadline(tb, target, limit) return s.extension(target, granted) } @@ -215,8 +252,11 @@ func (s *contextState) extension(deadline time.Time, granted time.Duration) Exte return Extension{ Deadline: deadline, Granted: granted, + Report: s.extensionReportLocked(), + Limit: s.deadlineLimit, currentContext: s.ctx, contexts: append([]context.Context(nil), s.contexts...), + state: s, } } @@ -267,9 +307,9 @@ func (s *contextState) err() error { return s.ctx.Err() } -func (s *contextState) setDeadline(tb testing.TB, deadline time.Time) { +func (s *contextState) setDeadline(tb testing.TB, deadline time.Time, limit string) { if goTestDeadline, ok := tb.Context().Deadline(); ok && goTestDeadline.Before(deadline) { - deadline = goTestDeadline + limit = deadlineLimitGoTestTimeout } ctx, cancel := context.WithDeadline(tb.Context(), deadline) @@ -283,6 +323,7 @@ func (s *contextState) setDeadline(tb testing.TB, deadline time.Time) { s.ctx = ctx s.contexts = append(s.contexts, ctx) s.cancels = append(s.cancels, cancel) + s.deadlineLimit = limit } func (s *contextState) cancel() { @@ -293,14 +334,89 @@ func (s *contextState) cancel() { } } -func (s *contextState) currentTimeout() time.Duration { +func (s *contextState) timeoutExceededError(err error) error { + return fmt.Errorf("%w: %s", err, s.timeoutExceededMessage()) +} + +func (s *contextState) markTimeoutReported() bool { + s.mu.Lock() + defer s.mu.Unlock() + if s.timeoutReported { + return false + } + s.timeoutReported = true + return true +} + +func (s *contextState) timeoutExceededMessage() string { s.mu.Lock() defer s.mu.Unlock() + currentTimeout := s.originalTimeout if deadline, ok := s.ctx.Deadline(); ok { - return deadline.Sub(s.testStart) + currentTimeout = 
deadline.Sub(s.testStart) + } + if s.deadlineLimit == deadlineLimitGoTestTimeout { + return s.withExtensionReportLocked(fmt.Sprintf("test exceeded go test timeout before test context timeout of %v", reportDuration(s.originalTimeout))) + } + if currentTimeout <= s.originalTimeout { + return s.withExtensionReportLocked(fmt.Sprintf("test exceeded timeout of %v", reportDuration(s.originalTimeout))) + } + if currentTimeout-s.originalTimeout < s.extensionRequestedTotal { + return s.withExtensionReportLocked(fmt.Sprintf( + "test exceeded test context extension cap of %v (originally %v, extensions requested total %v)", + reportDuration(maxTestTimeout()), + reportDuration(s.originalTimeout), + reportDuration(s.extensionRequestedTotal), + )) + } + return s.withExtensionReportLocked(fmt.Sprintf("test exceeded extended timeout of %v (originally %v)", reportDuration(currentTimeout), reportDuration(s.originalTimeout))) +} + +func (s *contextState) withExtensionReportLocked(message string) string { + if report := s.extensionReportLocked(); report != "" { + return message + "\n" + report + } + return message +} + +func (s *contextState) extensionReportLocked() string { + if len(s.extensionGrants) == 0 && s.extensionDenied == 0 { + return "" + } + if len(s.extensionGrants) == 0 { + return contextExtensionDeniedMessage(s.extensionDenied) + } + var total time.Duration + for _, grant := range s.extensionGrants { + total += grant.duration + } + message := fmt.Sprintf("ctx extensions = %d (+%v total)", len(s.extensionGrants), reportDuration(total)) + for i, grant := range s.extensionGrants { + message += fmt.Sprintf("\n %d. 
+%v after %v", i+1, reportDuration(grant.duration), reportDuration(grant.elapsed)) + } + if s.extensionDenied > 0 { + message += fmt.Sprintf("\n%s", contextExtensionDeniedMessage(s.extensionDenied)) + } + return message +} + +func contextExtensionDeniedMessage(count int) string { + if count == 1 { + return "1 context extension denied" + } + return fmt.Sprintf("%d context extensions denied", count) +} + +func reportDuration(d time.Duration) string { + if d > -time.Millisecond && d < time.Millisecond { + rounded := d.Round(time.Microsecond) + if rounded == 0 { + return "0µs" + } + return rounded.String() } - return s.originalTimeout + return d.Round(time.Millisecond).String() } func effectiveTimeout(customTimeout time.Duration) (timeout time.Duration) { diff --git a/common/testing/testcontext/context_test.go b/common/testing/testcontext/context_test.go index 15215edec7..69c2e3e868 100644 --- a/common/testing/testcontext/context_test.go +++ b/common/testing/testcontext/context_test.go @@ -2,6 +2,8 @@ package testcontext import ( "context" + "fmt" + "strings" "sync" "sync/atomic" "testing" @@ -235,6 +237,21 @@ func TestEnsureRemaining(t *testing.T) { For(t, WithTimeout(100*time.Millisecond)) }) + t.Run("records extension grants", func(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + For(t, WithTimeout(5*time.Millisecond)) + extension1 := EnsureRemaining(t, 10*time.Millisecond) + require.Equal(t, 5*time.Millisecond, extension1.Granted) + extension2 := EnsureRemaining(t, 20*time.Millisecond) + require.Equal(t, 10*time.Millisecond, extension2.Granted) + + require.Equal(t, []time.Duration{ + 5 * time.Millisecond, + 10 * time.Millisecond, + }, extensionGrants(t)) + }) + }) + t.Run("recognizes older context after repeated extensions", func(t *testing.T) { synctest.Test(t, func(t *testing.T) { original := For(t, WithTimeout(5*time.Millisecond)) @@ -251,6 +268,52 @@ func TestEnsureRemaining(t *testing.T) { }) }) + t.Run("reports extension grant time below 
millisecond", func(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + For(t, WithTimeout(5*time.Millisecond)) + timer := time.NewTimer(94 * time.Microsecond) + <-timer.C + extension := EnsureRemaining(t, 10*time.Millisecond) + require.Positive(t, extension.Granted) +/* The 94µs timer wait above completes before the extension is requested, so the grant line is expected to report a sub-millisecond elapsed time. */ + require.Contains(t, extensionReport(t), "after 94µs") + }) + }) + + t.Run("records denied extensions", func(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + For(t, WithTimeout(5*time.Second)) + extension := EnsureRemaining(t, 10*time.Minute) + require.Positive(t, extension.Granted) + extension = EnsureRemaining(t, 10*time.Minute) + require.Zero(t, extension.Granted) + extension = EnsureRemaining(t, 10*time.Minute) + require.Zero(t, extension.Granted) +/* Only the two zero-grant calls above are expected to be counted as denied. */ + require.Contains(t, extensionReport(t), "2 context extensions denied") + }) + }) + + t.Run("reports only denied extensions", func(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + For(t, WithTimeout(maxTestTimeout())) + extension := EnsureRemaining(t, 10*time.Minute) + require.Zero(t, extension.Granted) +/* With no grants recorded, the report is expected to consist solely of the denied-count message. */ + require.Equal(t, "1 context extension denied", extensionReport(t)) + }) + }) + + t.Run("does not record sufficient remaining time as denied", func(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + For(t, WithTimeout(20*time.Millisecond)) + extension := EnsureRemaining(t, 10*time.Millisecond) + require.Zero(t, extension.Granted) +/* A request that grants nothing because enough time already remains must leave the report empty. */ + require.Empty(t, extensionReport(t)) + }) + }) + t.Run("safe concurrent calls", func(t *testing.T) { For(t, WithTimeout(100*time.Millisecond)) @@ -268,3 +331,149 @@ func TestEnsureRemaining(t *testing.T) { require.Zero(t, negative.Load()) }) } +/* TestTimeoutExceededCleanupMessage runs the cleanups registered on a recordingTB after the test context is done and compares the recorded error text for the original-timeout, extended-timeout and extension-cap cases. */ +func TestTimeoutExceededCleanupMessage(t *testing.T) { + t.Run("original timeout", func(t *testing.T) { + tb := newRecordingTB() + + synctest.Test(t, func(t *testing.T) { + ctx := For(tb, WithTimeout(5*time.Second)) + <-ctx.Done() + tb.runCleanups() + }) + + timeout := 5 * time.Second * debug.TimeoutMultiplier + require.Equal(t, 
fmt.Sprintf("context deadline exceeded: test exceeded timeout of %v", timeout), tb.errors()) + }) + + t.Run("extended timeout", func(t *testing.T) { + tb := newRecordingTB() + + synctest.Test(t, func(t *testing.T) { + ctx := For(tb, WithTimeout(5*time.Second)) + extension := EnsureRemaining(tb, 10*time.Second) + require.Equal(t, 5*time.Second, extension.Granted) + + refreshed := For(tb) + <-refreshed.Done() + require.ErrorIs(t, ctx.Err(), context.DeadlineExceeded, "old context keeps its original deadline") + tb.runCleanups() + }) +/* NOTE(review): original is scaled by debug.TimeoutMultiplier, but extended, the 5s grant assertion above and the literal +5s lines below are not; these expectations look valid only when the multiplier is 1 - confirm. */ + original := 5 * time.Second * debug.TimeoutMultiplier + extended := 10 * time.Second + require.Equal(t, strings.Join([]string{ + fmt.Sprintf("context deadline exceeded: test exceeded extended timeout of %v (originally %v)", extended, original), + "ctx extensions = 1 (+5s total)", + " 1. +5s after 0µs", + }, "\n"), tb.errors()) + }) + + t.Run("extension cap", func(t *testing.T) { + tb := newRecordingTB() + + synctest.Test(t, func(t *testing.T) { + For(tb, WithTimeout(5*time.Second)) + extension := EnsureRemaining(tb, 10*time.Minute) + require.Positive(t, extension.Granted) + + refreshed := For(tb) + <-refreshed.Done() + tb.runCleanups() + }) +/* requested is the shortfall between the 10-minute request and the scaled original timeout; the single expected grant is the distance from original up to maxTestTimeout(). */ + original := 5 * time.Second * debug.TimeoutMultiplier + requested := 10*time.Minute - original + require.Equal(t, + strings.Join([]string{ + fmt.Sprintf("context deadline exceeded: test exceeded test context extension cap of %v (originally %v, extensions requested total %v)", maxTestTimeout(), original, requested), + fmt.Sprintf("ctx extensions = 1 (+%v total)", maxTestTimeout()-original), + fmt.Sprintf(" 1. 
+%v after 0µs", maxTestTimeout()-original), + }, "\n"), + tb.errors(), + ) + }) +} +/* extensionGrants returns the durations of the extension grants recorded in the context state registered for tb, in stored order, or nil when no state is registered. The state is looked up under the testContexts lock and its grants are copied under the state mutex. */ +func extensionGrants(tb testing.TB) []time.Duration { + tb.Helper() + + testContexts.Lock() + st, ok := testContexts.byTest[tb] + testContexts.Unlock() + if !ok { + return nil + } + + st.mu.Lock() + defer st.mu.Unlock() + grants := make([]time.Duration, 0, len(st.extensionGrants)) + for _, grant := range st.extensionGrants { + grants = append(grants, grant.duration) + } + return grants +} +/* extensionReport returns the result of extensionReportLocked for the context state registered for tb, called with the state mutex held, or the empty string when no state is registered. */ +func extensionReport(tb testing.TB) string { + tb.Helper() + + testContexts.Lock() + st, ok := testContexts.byTest[tb] + testContexts.Unlock() + if !ok { + return "" + } + + st.mu.Lock() + defer st.mu.Unlock() + return st.extensionReportLocked() +} +/* recordingTB is a testing.TB double that records Errorf messages and Cleanup callbacks under mu. newRecordingTB leaves the embedded testing.TB nil, so any TB method not overridden on recordingTB would be invoked on a nil interface and panic. */ +type recordingTB struct { + testing.TB + mu sync.Mutex + errorMessages []string + cleanups []func() +} +/* newRecordingTB returns an empty recordingTB. */ +func newRecordingTB() *recordingTB { + return &recordingTB{} +} +/* Helper is a no-op; Name returns a fixed placeholder name. */ +func (r *recordingTB) Helper() {} +func (r *recordingTB) Name() string { + return "recordingTB" +} +/* Context returns context.Background() rather than a test-scoped context. */ +func (r *recordingTB) Context() context.Context { + return context.Background() +} +/* Cleanup stores fn so that a later runCleanups call can invoke it. */ +func (r *recordingTB) Cleanup(fn func()) { + r.mu.Lock() + defer r.mu.Unlock() + r.cleanups = append(r.cleanups, fn) +} +/* Errorf records the formatted message; it does not mark any test as failed. */ +func (r *recordingTB) Errorf(format string, args ...any) { + r.mu.Lock() + defer r.mu.Unlock() + r.errorMessages = append(r.errorMessages, fmt.Sprintf(format, args...)) +} +/* runCleanups detaches the stored callbacks while holding mu, then runs them with mu released in last-registered-first order, so a callback may call Errorf or Cleanup on r without deadlocking. */ +func (r *recordingTB) runCleanups() { + r.mu.Lock() + cleanups := r.cleanups + r.cleanups = nil + r.mu.Unlock() + + for i := len(cleanups) - 1; i >= 0; i-- { + cleanups[i]() + } +} +/* errors returns every recorded Errorf message joined by newlines. */ +func (r *recordingTB) errors() string { + r.mu.Lock() + defer r.mu.Unlock() + return strings.Join(r.errorMessages, "\n") +}