Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 15 additions & 15 deletions common/testing/await/require_ctx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ func TestRequire_PropagatesParentContextValues(t *testing.T) {
func TestRequire_SetsTimeoutContextDeadline(t *testing.T) {
t.Parallel()

longCtx, cancel := context.WithTimeout(testcontext.New(t), time.Minute)
longCtx, cancel := context.WithTimeout(testcontext.For(t), time.Minute)
defer cancel()
longDeadline, ok := longCtx.Deadline()
require.True(t, ok)
Expand Down Expand Up @@ -156,7 +156,7 @@ func TestRequire_FailureScenarios(t *testing.T) {
t.Run("reports timeout", func(t *testing.T) {
t.Parallel()

ctx := testcontext.New(t)
ctx := testcontext.For(t)
tb := newRecordingTB()
tb.run(func() {
await.Require(ctx, tb, func(t *await.T) {
Expand All @@ -170,7 +170,7 @@ func TestRequire_FailureScenarios(t *testing.T) {
t.Run("cancels attempt context on timeout", func(t *testing.T) {
t.Parallel()

ctx := testcontext.New(t)
ctx := testcontext.For(t)
tb := newRecordingTB()
tb.run(func() {
await.Require(ctx, tb, func(t *await.T) {
Expand All @@ -190,7 +190,7 @@ func TestRequire_FailureScenarios(t *testing.T) {
pollInterval := 100 * time.Millisecond
t.Setenv("TEMPORAL_AWAIT_ATTEMPT_TIMEOUT", attemptTimeoutEnv.String())

ctx := testcontext.New(t)
ctx := testcontext.For(t)
var attempts atomic.Int32
var firstAttemptRemaining time.Duration

Expand All @@ -215,7 +215,7 @@ func TestRequire_FailureScenarios(t *testing.T) {
t.Run("does not poll again after attempt consumes timeout", func(t *testing.T) {
t.Parallel()

ctx := testcontext.New(t)
ctx := testcontext.For(t)
var attempts atomic.Int32

tb := newRecordingTB()
Expand All @@ -233,7 +233,7 @@ func TestRequire_FailureScenarios(t *testing.T) {
t.Run("caps attempt context with parent deadline", func(t *testing.T) {
t.Parallel()

parentCtx, cancel := context.WithTimeout(testcontext.New(t), time.Second)
parentCtx, cancel := context.WithTimeout(testcontext.For(t), time.Second)
defer cancel()

tb := newRecordingTB()
Expand All @@ -259,7 +259,7 @@ func TestRequire_FailureScenarios(t *testing.T) {
t.Run("parent context cancellation stops polling", func(t *testing.T) {
t.Parallel()

parentCtx, cancel := context.WithCancel(testcontext.New(t))
parentCtx, cancel := context.WithCancel(testcontext.For(t))
defer cancel()
var attempts atomic.Int32

Expand All @@ -280,7 +280,7 @@ func TestRequire_FailureScenarios(t *testing.T) {
t.Run("reports all attempt errors on timeout", func(t *testing.T) {
t.Parallel()

ctx := testcontext.New(t)
ctx := testcontext.For(t)
var attempts atomic.Int32
tb := newRecordingTB()
tb.run(func() {
Expand All @@ -302,7 +302,7 @@ func TestRequire_FailureScenarios(t *testing.T) {
t.Run("truncates middle attempts when many fail", func(t *testing.T) {
t.Parallel()

ctx := testcontext.New(t)
ctx := testcontext.For(t)
var attempts atomic.Int32
tb := newRecordingTB()
tb.run(func() {
Expand All @@ -329,7 +329,7 @@ func TestRequire_FailureScenarios(t *testing.T) {
t.Run("Requiref includes message on timeout", func(t *testing.T) {
t.Parallel()

ctx := testcontext.New(t)
ctx := testcontext.For(t)
tb := newRecordingTB()
tb.run(func() {
await.Requiref(ctx, tb, func(t *await.T) {
Expand Down Expand Up @@ -363,7 +363,7 @@ func TestRequire_FailureScenarios(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

ctx := testcontext.New(t)
ctx := testcontext.For(t)
tb := newRecordingTB()
tb.run(func() {
await.Require(ctx, tb, func(_ *await.T) {
Expand All @@ -379,7 +379,7 @@ func TestRequire_FailureScenarios(t *testing.T) {
t.Run("does not poll after prior failure", func(t *testing.T) {
t.Parallel()

ctx := testcontext.New(t)
ctx := testcontext.For(t)
conditionCalled := false
tb := newRecordingTB()
tb.run(func() {
Expand All @@ -401,7 +401,7 @@ func TestRequire_SoftDeadlockLogsAndCancels(t *testing.T) {

const awaitTimeout = 10 * time.Second

ctx := testcontext.New(t)
ctx := testcontext.For(t)
tb := newRecordingTB()
start := time.Now()
tb.run(func() {
Expand All @@ -427,7 +427,7 @@ func TestRequire_DeadlockDetected(t *testing.T) {

const awaitTimeout = 10 * time.Second

ctx := testcontext.New(t)
ctx := testcontext.For(t)
tb := newRecordingTB()
start := time.Now()
tb.run(func() {
Expand All @@ -448,7 +448,7 @@ func TestRequire_WaitsForInFlightAttemptOnTimeout(t *testing.T) {
t.Parallel()

var finished atomic.Bool
ctx := testcontext.New(t)
ctx := testcontext.For(t)
tb := newRecordingTB()
tb.run(func() {
await.Require(ctx, tb, func(t *await.T) {
Expand Down
4 changes: 2 additions & 2 deletions common/testing/await/require_true.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ const requireTrueMisuseHint = "do not use test assertions inside the predicate -
// side effects in the predicate - use [Require] for these.
func RequireTrue(tb testing.TB, condition func() bool, timeout, pollInterval time.Duration) {
tb.Helper()
run(testcontext.New(tb), tb, func(t *T) {
run(testcontext.For(tb), tb, func(t *T) {
if !condition() {
t.Fail()
}
Expand All @@ -28,7 +28,7 @@ func RequireTrue(tb testing.TB, condition func() bool, timeout, pollInterval tim
// in the failure message when the condition is not satisfied before the timeout.
func RequireTruef(tb testing.TB, condition func() bool, timeout, pollInterval time.Duration, msg string, args ...any) {
tb.Helper()
run(testcontext.New(tb), tb, func(t *T) {
run(testcontext.For(tb), tb, func(t *T) {
if !condition() {
t.Fail()
}
Expand Down
4 changes: 2 additions & 2 deletions common/testing/parallelsuite/suite.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ type Suite[T testingSuite] struct {

// copySuite creates a fresh suite instance initialized for the given *testing.T.
// assertT overrides which TestingT assertions are bound to; nil means use the copy's own guardT.
// ctx overrides the suite's context; nil means use the default (lazy testcontext.New).
// ctx overrides the suite's context; nil means use the default (lazy testcontext.For).
//
//nolint:revive // ctx is last so callers can pass nil to mean "no override"; SA1012 forbids passing nil as the first ctx arg.
func (s *Suite[T]) copySuite(t *testing.T, parallel bool, assertT require.TestingT, ctx context.Context) testingSuite {
Expand Down Expand Up @@ -88,7 +88,7 @@ func (s *Suite[T]) T() *testing.T {
func (s *Suite[T]) Context() context.Context {
s.ctxOnce.Do(func() {
if s.ctx == nil {
s.ctx = testcontext.New(s.T())
s.ctx = testcontext.For(s.T())
}
})
return s.ctx
Expand Down
4 changes: 2 additions & 2 deletions common/testing/parallelsuite/suite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,9 @@ func (s *contextSuite) TestContextHasDeadline() {
func (s *contextSuite) TestAwaitUsesSuiteContext() {
type key struct{}

testcontext.New(s.T(), testcontext.WithContextDecorator(key{}, func(ctx context.Context) context.Context {
testcontext.AttachDecorator(s.T(), key{}, func(ctx context.Context) context.Context {
return context.WithValue(ctx, key{}, "decorated")
}))
})

s.Await(func(s *contextSuite) {
s.Equal("decorated", s.Context().Value(key{}))
Expand Down
78 changes: 46 additions & 32 deletions common/testing/testcontext/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,24 +30,28 @@ var testContexts = contextStore{
type config struct {
timeout time.Duration
timeoutSet bool
decorators []contextDecorator
}

type contextDecorator struct {
key any
decorate func(context.Context) context.Context
}

// New returns the test-scoped context for tb. The context is canceled when the
// test ends or when the configured test timeout expires.
// DefaultTimeout returns the effective default timeout for test-scoped contexts.
func DefaultTimeout() time.Duration {
return effectiveTimeout(0)
}

// For returns the test-scoped context for tb. The context is canceled
// when the test ends or when the configured test timeout expires.
//
// The first call creates the per-test context and fixes its timeout. Later calls
// may add decorators, but an explicit different timeout fails instead of being
// silently ignored.
func New(tb testing.TB, opts ...Option) context.Context {
// return the same context, but an explicit different timeout fails instead of
// being silently ignored.
func For(tb testing.TB, opts ...Option) context.Context {

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

New wasn't quite right as this doesn't create one if it already exists.

tb.Helper()

cfg := config{timeout: effectiveTimeout(0)}
cfg := config{timeout: DefaultTimeout()}
for _, opt := range opts {
opt(&cfg)
}
Expand All @@ -57,7 +61,7 @@ func New(tb testing.TB, opts ...Option) context.Context {
return st.context()
}

// Option configures the test-scoped context returned by [New].
// Option configures the test-scoped context returned by [For].
type Option func(*config)

// WithTimeout sets a custom timeout for the test-scoped context.
Expand All @@ -71,15 +75,18 @@ func WithTimeout(timeout time.Duration) Option {
}
}

// WithContextDecorator applies decorator to the test-scoped context once for key.
// Reusing the same key is a no-op.
func WithContextDecorator[K comparable](key K, decorator func(context.Context) context.Context) Option {

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was awkward as it had to be used together with New.

return func(cfg *config) {
cfg.decorators = append(cfg.decorators, contextDecorator{
key: key,
decorate: decorator,
})
}
// AttachDecorator applies decorator to the test-scoped context once for key.
// Reusing the same key is a no-op. If the test context does not exist yet,
// AttachDecorator creates it with the default timeout. Call [For] with [WithTimeout]
// first when using a custom timeout.
func AttachDecorator[K comparable](tb testing.TB, key K, decorator func(context.Context) context.Context) {
tb.Helper()

st := getContextState(tb, DefaultTimeout())
st.attachDecorator(tb, contextDecorator{
key: key,
decorate: decorator,
})
}

type contextState struct {
Expand Down Expand Up @@ -114,12 +121,13 @@ func getContextState(tb testing.TB, timeout time.Duration) *contextState {
testContexts.byTest[tb] = st

tb.Cleanup(func() {
err := st.err()
st.cancel()
testContexts.Lock()
delete(testContexts.byTest, tb)
testContexts.Unlock()
if st.err() == context.DeadlineExceeded {
tb.Errorf("Test exceeded timeout of %v", st.timeout)
if err == context.DeadlineExceeded {
tb.Errorf("test exceeded timeout of %v", st.timeout)
}
})
return st
Expand All @@ -135,21 +143,27 @@ func (s *contextState) configure(tb testing.TB, cfg config) {
tb.Fatalf("testcontext: test context already exists with timeout %v; cannot change it to %v", s.timeout, cfg.timeout)
}

}

func (s *contextState) attachDecorator(tb testing.TB, decorator contextDecorator) {
tb.Helper()

// Decorators may be registered by independent helpers, so apply each keyed
// decorator at most once while preserving call order.
for _, decorator := range cfg.decorators {
if decorator.key == nil {
tb.Fatal("testcontext: context decorator key must not be nil")
}
if decorator.decorate == nil {
tb.Fatal("testcontext: context decorator must not be nil")
}
if _, ok := s.decorators[decorator.key]; ok {
continue
}
s.ctx = decorator.decorate(s.ctx)
s.decorators[decorator.key] = struct{}{}
s.mu.Lock()
defer s.mu.Unlock()

if decorator.key == nil {
tb.Fatal("testcontext: context decorator key must not be nil")
}
if decorator.decorate == nil {
tb.Fatal("testcontext: context decorator must not be nil")
}
if _, ok := s.decorators[decorator.key]; ok {
return
}
s.ctx = decorator.decorate(s.ctx)
s.decorators[decorator.key] = struct{}{}
}

func (s *contextState) context() context.Context {
Expand Down Expand Up @@ -182,6 +196,6 @@ func effectiveTimeout(customTimeout time.Duration) (timeout time.Duration) {
}
}

// 3. Default 90 seconds.
// 3. Default timeout.
return defaultTimeout
}
Loading
Loading