diff --git a/common/testing/await/require_ctx_test.go b/common/testing/await/require_ctx_test.go index e2f27f9fcd4..1872ea28571 100644 --- a/common/testing/await/require_ctx_test.go +++ b/common/testing/await/require_ctx_test.go @@ -100,7 +100,7 @@ func TestRequire_PropagatesParentContextValues(t *testing.T) { func TestRequire_SetsTimeoutContextDeadline(t *testing.T) { t.Parallel() - longCtx, cancel := context.WithTimeout(testcontext.New(t), time.Minute) + longCtx, cancel := context.WithTimeout(testcontext.For(t), time.Minute) defer cancel() longDeadline, ok := longCtx.Deadline() require.True(t, ok) @@ -156,7 +156,7 @@ func TestRequire_FailureScenarios(t *testing.T) { t.Run("reports timeout", func(t *testing.T) { t.Parallel() - ctx := testcontext.New(t) + ctx := testcontext.For(t) tb := newRecordingTB() tb.run(func() { await.Require(ctx, tb, func(t *await.T) { @@ -170,7 +170,7 @@ func TestRequire_FailureScenarios(t *testing.T) { t.Run("cancels attempt context on timeout", func(t *testing.T) { t.Parallel() - ctx := testcontext.New(t) + ctx := testcontext.For(t) tb := newRecordingTB() tb.run(func() { await.Require(ctx, tb, func(t *await.T) { @@ -190,7 +190,7 @@ func TestRequire_FailureScenarios(t *testing.T) { pollInterval := 100 * time.Millisecond t.Setenv("TEMPORAL_AWAIT_ATTEMPT_TIMEOUT", attemptTimeoutEnv.String()) - ctx := testcontext.New(t) + ctx := testcontext.For(t) var attempts atomic.Int32 var firstAttemptRemaining time.Duration @@ -215,7 +215,7 @@ func TestRequire_FailureScenarios(t *testing.T) { t.Run("does not poll again after attempt consumes timeout", func(t *testing.T) { t.Parallel() - ctx := testcontext.New(t) + ctx := testcontext.For(t) var attempts atomic.Int32 tb := newRecordingTB() @@ -233,7 +233,7 @@ func TestRequire_FailureScenarios(t *testing.T) { t.Run("caps attempt context with parent deadline", func(t *testing.T) { t.Parallel() - parentCtx, cancel := context.WithTimeout(testcontext.New(t), time.Second) + parentCtx, cancel := context.WithTimeout(testcontext.For(t), time.Second) defer cancel() tb := newRecordingTB() @@ -259,7 +259,7 @@ func TestRequire_FailureScenarios(t *testing.T) { t.Run("parent context cancellation stops polling", func(t *testing.T) { t.Parallel() - parentCtx, cancel := context.WithCancel(testcontext.New(t)) + parentCtx, cancel := context.WithCancel(testcontext.For(t)) defer cancel() var attempts atomic.Int32 @@ -280,7 +280,7 @@ func TestRequire_FailureScenarios(t *testing.T) { t.Run("reports all attempt errors on timeout", func(t *testing.T) { t.Parallel() - ctx := testcontext.New(t) + ctx := testcontext.For(t) var attempts atomic.Int32 tb := newRecordingTB() tb.run(func() { @@ -302,7 +302,7 @@ func TestRequire_FailureScenarios(t *testing.T) { t.Run("truncates middle attempts when many fail", func(t *testing.T) { t.Parallel() - ctx := testcontext.New(t) + ctx := testcontext.For(t) var attempts atomic.Int32 tb := newRecordingTB() tb.run(func() { @@ -329,7 +329,7 @@ func TestRequire_FailureScenarios(t *testing.T) { t.Run("Requiref includes message on timeout", func(t *testing.T) { t.Parallel() - ctx := testcontext.New(t) + ctx := testcontext.For(t) tb := newRecordingTB() tb.run(func() { await.Requiref(ctx, tb, func(t *await.T) { @@ -363,7 +363,7 @@ func TestRequire_FailureScenarios(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - ctx := testcontext.New(t) + ctx := testcontext.For(t) tb := newRecordingTB() tb.run(func() { await.Require(ctx, tb, func(_ *await.T) { @@ -379,7 +379,7 @@ func TestRequire_FailureScenarios(t *testing.T) { t.Run("does not poll after prior failure", func(t *testing.T) { t.Parallel() - ctx := testcontext.New(t) + ctx := testcontext.For(t) conditionCalled := false tb := newRecordingTB() tb.run(func() { @@ -401,7 +401,7 @@ func TestRequire_SoftDeadlockLogsAndCancels(t *testing.T) { const awaitTimeout = 10 * time.Second - ctx := testcontext.New(t) + ctx := testcontext.For(t) tb := newRecordingTB() start := time.Now() tb.run(func() { @@ -427,7 +427,7 @@ func TestRequire_DeadlockDetected(t *testing.T) { const awaitTimeout = 10 * time.Second - ctx := testcontext.New(t) + ctx := testcontext.For(t) tb := newRecordingTB() start := time.Now() tb.run(func() { @@ -448,7 +448,7 @@ func TestRequire_WaitsForInFlightAttemptOnTimeout(t *testing.T) { t.Parallel() var finished atomic.Bool - ctx := testcontext.New(t) + ctx := testcontext.For(t) tb := newRecordingTB() tb.run(func() { await.Require(ctx, tb, func(t *await.T) { diff --git a/common/testing/await/require_true.go b/common/testing/await/require_true.go index 3183430064e..00d0846c452 100644 --- a/common/testing/await/require_true.go +++ b/common/testing/await/require_true.go @@ -17,7 +17,7 @@ const requireTrueMisuseHint = "do not use test assertions inside the predicate - // side effects in the predicate - use [Require] for these. func RequireTrue(tb testing.TB, condition func() bool, timeout, pollInterval time.Duration) { tb.Helper() - run(testcontext.New(tb), tb, func(t *T) { + run(testcontext.For(tb), tb, func(t *T) { if !condition() { t.Fail() } @@ -28,7 +28,7 @@ func RequireTrue(tb testing.TB, condition func() bool, timeout, pollInterval tim // in the failure message when the condition is not satisfied before the timeout. func RequireTruef(tb testing.TB, condition func() bool, timeout, pollInterval time.Duration, msg string, args ...any) { tb.Helper() - run(testcontext.New(tb), tb, func(t *T) { + run(testcontext.For(tb), tb, func(t *T) { if !condition() { t.Fail() } diff --git a/common/testing/parallelsuite/suite.go b/common/testing/parallelsuite/suite.go index ca326454499..5584e1fd59e 100644 --- a/common/testing/parallelsuite/suite.go +++ b/common/testing/parallelsuite/suite.go @@ -46,7 +46,7 @@ type Suite[T testingSuite] struct { // copySuite creates a fresh suite instance initialized for the given *testing.T. // assertT overrides which TestingT assertions are bound to; nil means use the copy's own guardT. -// ctx overrides the suite's context; nil means use the default (lazy testcontext.New). +// ctx overrides the suite's context; nil means use the default (lazy testcontext.For). // //nolint:revive // ctx is last so callers can pass nil to mean "no override"; SA1012 forbids passing nil as the first ctx arg. func (s *Suite[T]) copySuite(t *testing.T, parallel bool, assertT require.TestingT, ctx context.Context) testingSuite { @@ -88,7 +88,7 @@ func (s *Suite[T]) T() *testing.T { func (s *Suite[T]) Context() context.Context { s.ctxOnce.Do(func() { if s.ctx == nil { - s.ctx = testcontext.New(s.T()) + s.ctx = testcontext.For(s.T()) } }) return s.ctx diff --git a/common/testing/parallelsuite/suite_test.go b/common/testing/parallelsuite/suite_test.go index 60b5cd6914e..4a88822ca3e 100644 --- a/common/testing/parallelsuite/suite_test.go +++ b/common/testing/parallelsuite/suite_test.go @@ -88,9 +88,9 @@ func (s *contextSuite) TestContextHasDeadline() { func (s *contextSuite) TestAwaitUsesSuiteContext() { type key struct{} - testcontext.New(s.T(), testcontext.WithContextDecorator(key{}, func(ctx context.Context) context.Context { + testcontext.AttachDecorator(s.T(), key{}, func(ctx context.Context) context.Context { return context.WithValue(ctx, key{}, "decorated") - })) + }) s.Await(func(s *contextSuite) { s.Equal("decorated", s.Context().Value(key{})) diff --git a/common/testing/testcontext/context.go b/common/testing/testcontext/context.go index 5c13b3461b4..430558e9fa2 100644 --- a/common/testing/testcontext/context.go +++ b/common/testing/testcontext/context.go @@ -30,7 +30,6 @@ var testContexts = contextStore{ type config struct { timeout time.Duration timeoutSet bool - decorators []contextDecorator } type contextDecorator struct { @@ -38,16 +37,21 @@ type contextDecorator struct { decorate func(context.Context) context.Context } -// New returns the test-scoped context for tb. The context is canceled when the -// test ends or when the configured test timeout expires. +// DefaultTimeout returns the effective default timeout for test-scoped contexts. +func DefaultTimeout() time.Duration { + return effectiveTimeout(0) +} + +// For returns the test-scoped context for tb. The context is canceled +// when the test ends or when the configured test timeout expires. // // The first call creates the per-test context and fixes its timeout. Later calls -// may add decorators, but an explicit different timeout fails instead of being -// silently ignored. -func New(tb testing.TB, opts ...Option) context.Context { +// return the same context, but an explicit different timeout fails instead of +// being silently ignored. +func For(tb testing.TB, opts ...Option) context.Context { tb.Helper() - cfg := config{timeout: effectiveTimeout(0)} + cfg := config{timeout: DefaultTimeout()} for _, opt := range opts { opt(&cfg) } @@ -57,7 +61,7 @@ func New(tb testing.TB, opts ...Option) context.Context { return st.context() } -// Option configures the test-scoped context returned by [New]. +// Option configures the test-scoped context returned by [For]. type Option func(*config) // WithTimeout sets a custom timeout for the test-scoped context. @@ -71,15 +75,18 @@ func WithTimeout(timeout time.Duration) Option { } } -// WithContextDecorator applies decorator to the test-scoped context once for key. -// Reusing the same key is a no-op. -func WithContextDecorator[K comparable](key K, decorator func(context.Context) context.Context) Option { - return func(cfg *config) { - cfg.decorators = append(cfg.decorators, contextDecorator{ - key: key, - decorate: decorator, - }) - } +// AttachDecorator applies decorator to the test-scoped context once for key. +// Reusing the same key is a no-op. If the test context does not exist yet, +// AttachDecorator creates it with the default timeout. Call [For] with [WithTimeout] +// first when using a custom timeout. +func AttachDecorator[K comparable](tb testing.TB, key K, decorator func(context.Context) context.Context) { + tb.Helper() + + st := getContextState(tb, DefaultTimeout()) + st.attachDecorator(tb, contextDecorator{ + key: key, + decorate: decorator, + }) } type contextState struct { @@ -114,12 +121,13 @@ func getContextState(tb testing.TB, timeout time.Duration) *contextState { testContexts.byTest[tb] = st tb.Cleanup(func() { + err := st.err() st.cancel() testContexts.Lock() delete(testContexts.byTest, tb) testContexts.Unlock() - if st.err() == context.DeadlineExceeded { - tb.Errorf("Test exceeded timeout of %v", st.timeout) + if err == context.DeadlineExceeded { + tb.Errorf("test exceeded timeout of %v", st.timeout) } }) return st @@ -135,21 +143,27 @@ func (s *contextState) configure(tb testing.TB, cfg config) { tb.Fatalf("testcontext: test context already exists with timeout %v; cannot change it to %v", s.timeout, cfg.timeout) } +} + +func (s *contextState) attachDecorator(tb testing.TB, decorator contextDecorator) { + tb.Helper() + // Decorators may be registered by independent helpers, so apply each keyed // decorator at most once while preserving call order. - for _, decorator := range cfg.decorators { - if decorator.key == nil { - tb.Fatal("testcontext: context decorator key must not be nil") - } - if decorator.decorate == nil { - tb.Fatal("testcontext: context decorator must not be nil") - } - if _, ok := s.decorators[decorator.key]; ok { - continue - } - s.ctx = decorator.decorate(s.ctx) - s.decorators[decorator.key] = struct{}{} + s.mu.Lock() + defer s.mu.Unlock() + + if decorator.key == nil { + tb.Fatal("testcontext: context decorator key must not be nil") + } + if decorator.decorate == nil { + tb.Fatal("testcontext: context decorator must not be nil") + } + if _, ok := s.decorators[decorator.key]; ok { + return } + s.ctx = decorator.decorate(s.ctx) + s.decorators[decorator.key] = struct{}{} } func (s *contextState) context() context.Context { @@ -182,6 +196,6 @@ func effectiveTimeout(customTimeout time.Duration) (timeout time.Duration) { } } - // 3. Default 90 seconds. + // 3. Default timeout. return defaultTimeout } diff --git a/common/testing/testcontext/context_test.go b/common/testing/testcontext/context_test.go index b0d8f5390ca..4fe6955f756 100644 --- a/common/testing/testcontext/context_test.go +++ b/common/testing/testcontext/context_test.go @@ -4,25 +4,47 @@ import ( "context" "sync/atomic" "testing" + "testing/synctest" "time" "github.com/stretchr/testify/require" + "go.temporal.io/server/common/debug" "google.golang.org/grpc/metadata" ) func TestWithTimeout(t *testing.T) { t.Parallel() - ctx := New(t, WithTimeout(time.Second)) - deadline, ok := ctx.Deadline() - require.True(t, ok) - require.WithinDuration(t, time.Now().Add(time.Second), deadline, 50*time.Millisecond) + t.Run("default", func(t *testing.T) { + t.Parallel() + + synctest.Test(t, func(t *testing.T) { + start := time.Now() + ctx := For(t) + deadline, ok := ctx.Deadline() + require.True(t, ok) + require.Equal(t, start.Add(DefaultTimeout()), deadline) + require.Equal(t, 90*time.Second*debug.TimeoutMultiplier, DefaultTimeout()) + }) + }) + + t.Run("custom", func(t *testing.T) { + t.Parallel() + + synctest.Test(t, func(t *testing.T) { + start := time.Now() + ctx := For(t, WithTimeout(time.Second)) + deadline, ok := ctx.Deadline() + require.True(t, ok) + require.Equal(t, start.Add(time.Second), deadline) + }) + }) } func TestNameMetadata(t *testing.T) { t.Parallel() - ctx := New(t) + ctx := For(t) md, ok := metadata.FromOutgoingContext(ctx) require.True(t, ok) require.Equal(t, []string{t.Name()}, md.Get(testNameMetadataKey)) @@ -42,15 +64,17 @@ func TestContextDecorators(t *testing.T) { return context.WithValue(ctx, key{}, "decorated") } - ctx := New(t, WithContextDecorator(key{}, decorator)) + AttachDecorator(t, key{}, decorator) + ctx := For(t) require.Equal(t, "decorated", ctx.Value(key{})) - ctx = New(t, WithContextDecorator(key{}, decorator)) + AttachDecorator(t, key{}, decorator) + ctx = For(t) require.Equal(t, "decorated", ctx.Value(key{})) require.Equal(t, int32(1), calls.Load(), "decorator should only be applied once") }) - t.Run("applied once in single call", func(t *testing.T) { + t.Run("applied once for same key", func(t *testing.T) { t.Parallel() type key struct{} @@ -61,10 +85,9 @@ func TestContextDecorators(t *testing.T) { return context.WithValue(ctx, key{}, "decorated") } - ctx := New(t, - WithContextDecorator(key{}, decorator), - WithContextDecorator(key{}, decorator), - ) + AttachDecorator(t, key{}, decorator) + AttachDecorator(t, key{}, decorator) + ctx := For(t) require.Equal(t, "decorated", ctx.Value(key{})) require.Equal(t, int32(1), calls.Load(), "decorator should only be applied once") @@ -76,14 +99,13 @@ func TestContextDecorators(t *testing.T) { type key1 struct{} type key2 struct{} - ctx := New(t, - WithContextDecorator(key1{}, func(ctx context.Context) context.Context { - return context.WithValue(ctx, key1{}, "one") - }), - WithContextDecorator(key2{}, func(ctx context.Context) context.Context { - return context.WithValue(ctx, key2{}, "two") - }), - ) + AttachDecorator(t, key1{}, func(ctx context.Context) context.Context { + return context.WithValue(ctx, key1{}, "one") + }) + AttachDecorator(t, key2{}, func(ctx context.Context) context.Context { + return context.WithValue(ctx, key2{}, "two") + }) + ctx := For(t) require.Equal(t, "one", ctx.Value(key1{})) require.Equal(t, "two", ctx.Value(key2{})) @@ -94,12 +116,13 @@ func TestContextDecorators(t *testing.T) { type key struct{} - ctx := New(t) + ctx := For(t) require.Nil(t, ctx.Value(key{})) - ctx = New(t, WithContextDecorator(key{}, func(ctx context.Context) context.Context { + AttachDecorator(t, key{}, func(ctx context.Context) context.Context { return context.WithValue(ctx, key{}, "decorated") - })) + }) + ctx = For(t) require.Equal(t, "decorated", ctx.Value(key{})) }) } @@ -109,7 +132,7 @@ func TestCleanupCancelsContext(t *testing.T) { var ctx context.Context t.Run("subtest", func(t *testing.T) { - ctx = New(t) + ctx = For(t) require.NoError(t, ctx.Err()) }) require.ErrorIs(t, ctx.Err(), context.Canceled) @@ -119,18 +142,24 @@ func TestEnvTimeout(t *testing.T) { t.Run("from env", func(t *testing.T) { t.Setenv("TEMPORAL_TEST_TIMEOUT", "10s") - ctx := New(t) - deadline, ok := ctx.Deadline() - require.True(t, ok) - require.WithinDuration(t, time.Now().Add(10*time.Second), deadline, 50*time.Millisecond) + synctest.Test(t, func(t *testing.T) { + start := time.Now() + ctx := For(t) + deadline, ok := ctx.Deadline() + require.True(t, ok) + require.Equal(t, start.Add(10*time.Second), deadline) + }) }) t.Run("custom overrides env", func(t *testing.T) { t.Setenv("TEMPORAL_TEST_TIMEOUT", "10s") - ctx := New(t, WithTimeout(time.Second)) - deadline, ok := ctx.Deadline() - require.True(t, ok) - require.WithinDuration(t, time.Now().Add(time.Second), deadline, 50*time.Millisecond) + synctest.Test(t, func(t *testing.T) { + start := time.Now() + ctx := For(t, WithTimeout(time.Second)) + deadline, ok := ctx.Deadline() + require.True(t, ok) + require.Equal(t, start.Add(time.Second), deadline) + }) }) } diff --git a/tests/testcore/context.go b/tests/testcore/context.go index ec61941da6c..a6812a6c5ed 100644 --- a/tests/testcore/context.go +++ b/tests/testcore/context.go @@ -2,16 +2,12 @@ package testcore import ( "context" - "testing" - "go.temporal.io/server/common/headers" "go.temporal.io/server/common/rpc" "go.temporal.io/server/common/testing/testcontext" ) -type versionHeadersContextKey struct{} - -// NewContext creates a context with default 90-second timeout and RPC headers. +// NewContext creates a context with default timeout and RPC headers. // // NOTE: If you're using testcore.NewEnv, you can use env.Context() directly - it already // includes RPC headers. This function is primarily for legacy tests or creating standalone @@ -24,23 +20,12 @@ func NewContext(parent ...context.Context) context.Context { // Create RPC context derived from parent ctx, _ := rpc.NewContextFromParentWithTimeoutAndVersionHeaders( parent[0], - defaultTestTimeout, + testcontext.DefaultTimeout(), ) return ctx } // Create standalone RPC context - ctx, _ := rpc.NewContextWithTimeoutAndVersionHeaders(defaultTestTimeout) + ctx, _ := rpc.NewContextWithTimeoutAndVersionHeaders(testcontext.DefaultTimeout()) return ctx } - -// setupTestTimeoutWithContext creates a context that will be canceled on timeout, -// and reports the timeout error during cleanup. Returns a context that tests can -// use to be interrupted when timeout occurs. The context includes RPC version headers. -func setupTestTimeoutWithContext(t *testing.T) context.Context { - t.Helper() - return testcontext.New( - t, - testcontext.WithContextDecorator(versionHeadersContextKey{}, headers.SetVersions), - ) -} diff --git a/tests/testcore/test_env.go b/tests/testcore/test_env.go index 0aaa7255fee..2efff0a6dd9 100644 --- a/tests/testcore/test_env.go +++ b/tests/testcore/test_env.go @@ -10,7 +10,6 @@ import ( "strings" "sync" "testing" - "time" "github.com/dgryski/go-farm" "github.com/stretchr/testify/require" @@ -22,12 +21,13 @@ import ( "go.temporal.io/server/common" "go.temporal.io/server/common/authorization" "go.temporal.io/server/common/config" - "go.temporal.io/server/common/debug" "go.temporal.io/server/common/dynamicconfig" + "go.temporal.io/server/common/headers" "go.temporal.io/server/common/log" "go.temporal.io/server/common/namespace" "go.temporal.io/server/common/primitives" "go.temporal.io/server/common/testing/taskpoller" + "go.temporal.io/server/common/testing/testcontext" "go.temporal.io/server/common/testing/testhooks" "go.temporal.io/server/common/testing/testlogger" "go.temporal.io/server/common/testing/testvars" @@ -40,10 +40,7 @@ import ( //go:embed shard_salt.txt var shardSalt string -var ( - _ Env = (*TestEnv)(nil) - defaultTestTimeout = 90 * time.Second * debug.TimeoutMultiplier -) +var _ Env = (*TestEnv)(nil) type Env interface { // T returns the *testing.T. @@ -104,6 +101,8 @@ type dynamicConfigOverride struct { value any } +type versionHeadersContextKey struct{} + // WithDedicatedCluster requests a dedicated (non-shared) cluster for the test. // Use this for tests that have cluster-global side effects. func WithDedicatedCluster() TestOption { @@ -266,6 +265,9 @@ func NewEnv(t *testing.T, opts ...TestOption) *TestEnv { tv = options.testVars(tv) } + // Attach version headers decorator to the test context. + testcontext.AttachDecorator(t, versionHeadersContextKey{}, headers.SetVersions) + env := &TestEnv{ FunctionalTestBase: base, Assertions: require.New(t), @@ -276,7 +278,7 @@ func NewEnv(t *testing.T, opts ...TestOption) *TestEnv { taskPoller: taskpoller.New(t, cluster.FrontendClient(), ns.String()), t: t, tv: tv, - ctx: setupTestTimeoutWithContext(t), + ctx: testcontext.For(t), sdkWorkerTQ: RandomizeStr("tq-" + t.Name()), dedicatedGuard: dedicatedGuard, }