diff --git a/common/backoff/jitter_test.go b/common/backoff/jitter_test.go index 9b3f5bf62fd..b56f6669ef9 100644 --- a/common/backoff/jitter_test.go +++ b/common/backoff/jitter_test.go @@ -5,21 +5,17 @@ import ( "testing" "time" - "github.com/stretchr/testify/suite" + "go.temporal.io/server/common/testing/parallelsuite" ) type ( jitterSuite struct { - suite.Suite + parallelsuite.Suite[*jitterSuite] } ) func TestJitterSuite(t *testing.T) { - s := new(jitterSuite) - suite.Run(t, s) -} - -func (s *jitterSuite) SetupSuite() { + parallelsuite.Run(t, new(jitterSuite)) } func (s *jitterSuite) TestJitter_Int64() { diff --git a/common/backoff/retry_test.go b/common/backoff/retry_test.go index 3f4af506cb3..bd542db9bda 100644 --- a/common/backoff/retry_test.go +++ b/common/backoff/retry_test.go @@ -5,26 +5,20 @@ import ( "testing" "time" - "github.com/stretchr/testify/require" - "github.com/stretchr/testify/suite" "go.temporal.io/api/serviceerror" + "go.temporal.io/server/common/testing/parallelsuite" ) type ( RetrySuite struct { - *require.Assertions // override suite.Suite.Assertions with require.Assertions; this means that s.NotNil(nil) will stop the test, not merely log an error - suite.Suite + parallelsuite.Suite[*RetrySuite] } someError struct{} ) func TestRetrySuite(t *testing.T) { - suite.Run(t, new(RetrySuite)) -} - -func (s *RetrySuite) SetupTest() { - s.Assertions = require.New(s.T()) // Have to define our overridden assertions in the test setup. If we did it earlier, s.T() will return nil + parallelsuite.RunLegacySequential(t, new(RetrySuite)) //nolint:staticcheck // SA1019: TestThrottleRetryContext temporarily changes package global retry policy. } func (s *RetrySuite) TestRetrySuccess() { diff --git a/common/backoff/retrypolicy_test.go b/common/backoff/retrypolicy_test.go index 78ee024a282..6ce58752a0c 100644 --- a/common/backoff/retrypolicy_test.go +++ b/common/backoff/retrypolicy_test.go @@ -7,15 +7,13 @@ import ( "testing" "time" - "github.com/stretchr/testify/require" - "github.com/stretchr/testify/suite" "go.temporal.io/server/common/clock" + "go.temporal.io/server/common/testing/parallelsuite" ) type ( RetryPolicySuite struct { - *require.Assertions // override suite.Suite.Assertions with require.Assertions; this means that s.NotNil(nil) will stop the test, not merely log an error - suite.Suite + parallelsuite.Suite[*RetryPolicySuite] } ) @@ -62,11 +60,7 @@ func ExampleExponentialRetryPolicy_WithMaximumInterval() { } func TestRetryPolicySuite(t *testing.T) { - suite.Run(t, new(RetryPolicySuite)) -} - -func (s *RetryPolicySuite) SetupTest() { - s.Assertions = require.New(s.T()) // Have to define our overridden assertions in the test setup. If we did it earlier, s.T() will return nil + parallelsuite.Run(t, new(RetryPolicySuite)) } func (s *RetryPolicySuite) TestExponentialBackoff() { diff --git a/common/circuitbreaker/circuitbreaker_test.go b/common/circuitbreaker/circuitbreaker_test.go index 47144e8a6d3..7546eab94be 100644 --- a/common/circuitbreaker/circuitbreaker_test.go +++ b/common/circuitbreaker/circuitbreaker_test.go @@ -4,22 +4,19 @@ import ( "testing" "time" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/suite" "go.temporal.io/server/common/dynamicconfig" + "go.temporal.io/server/common/testing/parallelsuite" ) -type TSCBWithDynamicSettingsTestSuite struct { - suite.Suite +type CircuitBreakerSuite struct { + parallelsuite.Suite[*CircuitBreakerSuite] } -func TestTSCBWithDynamicSettings(t *testing.T) { - suite.Run(t, &TSCBWithDynamicSettingsTestSuite{}) +func TestCircuitBreakerSuite(t *testing.T) { + parallelsuite.Run(t, new(CircuitBreakerSuite)) } -func TestBasic(t *testing.T) { - s := assert.New(t) - +func (s *CircuitBreakerSuite) TestBasic() { name := "test-tscb" tscb := NewTwoStepCircuitBreakerWithDynamicSettings(Settings{Name: name}) tscb.UpdateSettings(dynamicconfig.CircuitBreakerSettings{}) @@ -30,9 +27,7 @@ func TestBasic(t *testing.T) { doneFn(true) } -func TestDynamicSettings(t *testing.T) { - s := assert.New(t) - +func (s *CircuitBreakerSuite) TestDynamicSettings() { tscb := NewTwoStepCircuitBreakerWithDynamicSettings(Settings{}) tscb.UpdateSettings(dynamicconfig.CircuitBreakerSettings{}) cb1 := tscb.cb.Load() diff --git a/common/definition/resource_dedup_test.go b/common/definition/resource_dedup_test.go index ab6ea770a45..392c45cfbfd 100644 --- a/common/definition/resource_dedup_test.go +++ b/common/definition/resource_dedup_test.go @@ -4,18 +4,17 @@ import ( "fmt" "testing" - "github.com/stretchr/testify/suite" + "go.temporal.io/server/common/testing/parallelsuite" ) type ( resourceDeduplicationSuite struct { - suite.Suite + parallelsuite.Suite[*resourceDeduplicationSuite] } ) func TestResourceDeduplicationSuite(t *testing.T) { - s := new(resourceDeduplicationSuite) - suite.Run(t, s) + parallelsuite.Run(t, new(resourceDeduplicationSuite)) } func (s *resourceDeduplicationSuite) TestGenerateKey() { diff --git a/common/headers/caller_info_test.go b/common/headers/caller_info_test.go index 5a58c538212..961ac8b708e 100644 --- a/common/headers/caller_info_test.go +++ b/common/headers/caller_info_test.go @@ -4,24 +4,18 @@ import ( "context" "testing" - "github.com/stretchr/testify/require" - "github.com/stretchr/testify/suite" + "go.temporal.io/server/common/testing/parallelsuite" "google.golang.org/grpc/metadata" ) type ( callerInfoSuite struct { - *require.Assertions - suite.Suite + parallelsuite.Suite[*callerInfoSuite] } ) func TestCallerInfoSuite(t *testing.T) { - suite.Run(t, &callerInfoSuite{}) -} - -func (s *callerInfoSuite) SetupTest() { - s.Assertions = require.New(s.T()) + parallelsuite.Run(t, &callerInfoSuite{}) } func (s *callerInfoSuite) TestSetCallerName() { diff --git a/common/headers/version_checker_test.go b/common/headers/version_checker_test.go index 709edd6a568..bfcaff0709e 100644 --- a/common/headers/version_checker_test.go +++ b/common/headers/version_checker_test.go @@ -5,24 +5,18 @@ import ( "strings" "testing" - "github.com/stretchr/testify/require" - "github.com/stretchr/testify/suite" "go.temporal.io/api/serviceerror" + "go.temporal.io/server/common/testing/parallelsuite" ) type ( VersionCheckerSuite struct { - *require.Assertions - suite.Suite + parallelsuite.Suite[*VersionCheckerSuite] } ) func TestVersionCheckerSuite(t *testing.T) { - suite.Run(t, new(VersionCheckerSuite)) -} - -func (s *VersionCheckerSuite) SetupTest() { - s.Assertions = require.New(s.T()) + parallelsuite.Run(t, new(VersionCheckerSuite)) } func (s *VersionCheckerSuite) TestClientSupported() { diff --git a/common/number/number_test.go b/common/number/number_test.go index 7ea4ead12cc..811478b2f0a 100644 --- a/common/number/number_test.go +++ b/common/number/number_test.go @@ -4,32 +4,17 @@ import ( "math/rand" "testing" - "github.com/stretchr/testify/require" - "github.com/stretchr/testify/suite" + "go.temporal.io/server/common/testing/parallelsuite" ) type ( numberSuite struct { - suite.Suite - *require.Assertions + parallelsuite.Suite[*numberSuite] } ) func TestNumberSuite(t *testing.T) { - s := new(numberSuite) - suite.Run(t, s) -} - -func (s *numberSuite) SetupSuite() {} - -func (s *numberSuite) TearDownSuite() {} - -func (s *numberSuite) SetupTest() { - s.Assertions = require.New(s.T()) -} - -func (s *numberSuite) TearDownTest() { - + parallelsuite.Run(t, new(numberSuite)) } func (s *numberSuite) TestInt() { diff --git a/common/persistence/versionhistory/version_history_test.go b/common/persistence/versionhistory/version_history_test.go index 6d9378f775d..c36051dc08a 100644 --- a/common/persistence/versionhistory/version_history_test.go +++ b/common/persistence/versionhistory/version_history_test.go @@ -3,30 +3,28 @@ package versionhistory import ( "testing" - "github.com/stretchr/testify/suite" "go.temporal.io/api/serviceerror" historyspb "go.temporal.io/server/api/history/v1" "go.temporal.io/server/common" + "go.temporal.io/server/common/testing/parallelsuite" ) type ( versionHistorySuite struct { - suite.Suite + parallelsuite.Suite[*versionHistorySuite] } versionHistoriesSuite struct { - suite.Suite + parallelsuite.Suite[*versionHistoriesSuite] } ) func TestVersionHistorySuite(t *testing.T) { - s := new(versionHistorySuite) - suite.Run(t, s) + parallelsuite.Run(t, new(versionHistorySuite)) } func TestVersionHistoriesSuite(t *testing.T) { - s := new(versionHistoriesSuite) - suite.Run(t, s) + parallelsuite.Run(t, new(versionHistoriesSuite)) } func (s *versionHistorySuite) TestDuplicateUntilLCAItem_Success() { diff --git a/common/predicates/and_test.go b/common/predicates/and_test.go index 7078e60e71c..08a27dcce96 100644 --- a/common/predicates/and_test.go +++ b/common/predicates/and_test.go @@ -3,24 +3,17 @@ package predicates import ( "testing" - "github.com/stretchr/testify/require" - "github.com/stretchr/testify/suite" + "go.temporal.io/server/common/testing/parallelsuite" ) type ( andSuite struct { - suite.Suite - *require.Assertions + parallelsuite.Suite[*andSuite] } ) func TestAndSuite(t *testing.T) { - s := new(andSuite) - suite.Run(t, s) -} - -func (s *andSuite) SetupTest() { - s.Assertions = require.New(s.T()) + parallelsuite.Run(t, new(andSuite)) } func (s *andSuite) TestAnd_Normal() { diff --git a/common/predicates/empty_test.go b/common/predicates/empty_test.go index fcc7e34ca9c..0f3c1f77052 100644 --- a/common/predicates/empty_test.go +++ b/common/predicates/empty_test.go @@ -3,49 +3,40 @@ package predicates import ( "testing" - "github.com/stretchr/testify/require" - "github.com/stretchr/testify/suite" + "go.temporal.io/server/common/testing/parallelsuite" ) type ( emptySuite struct { - suite.Suite - *require.Assertions - - emtpy Predicate[int] + parallelsuite.Suite[*emptySuite] } ) func TestNoneSuite(t *testing.T) { - s := new(emptySuite) - suite.Run(t, s) -} - -func (s *emptySuite) SetupTest() { - s.Assertions = require.New(s.T()) - - s.emtpy = Empty[int]() + parallelsuite.Run(t, new(emptySuite)) } func (s *emptySuite) TestEmpty_Test() { + empty := Empty[int]() for i := 0; i != 10; i++ { - s.False(s.emtpy.Test(i)) + s.False(empty.Test(i)) } } func (s *emptySuite) TestEmpty_Equals() { - s.True(s.emtpy.Equals(s.emtpy)) - s.True(s.emtpy.Equals(Empty[int]())) + empty := Empty[int]() + s.True(empty.Equals(empty)) + s.True(empty.Equals(Empty[int]())) - s.False(s.emtpy.Equals(newTestPredicate(1, 2, 3))) - s.False(s.emtpy.Equals(And[int]( + s.False(empty.Equals(newTestPredicate(1, 2, 3))) + s.False(empty.Equals(And[int]( newTestPredicate(1, 2, 3), newTestPredicate(2, 3, 4), ))) - s.False(s.emtpy.Equals(Or[int]( + s.False(empty.Equals(Or[int]( newTestPredicate(1, 2, 3), newTestPredicate(4, 5, 6), ))) - s.False(s.emtpy.Equals(Not[int](newTestPredicate(1, 2, 3)))) - s.False(s.emtpy.Equals(Universal[int]())) + s.False(empty.Equals(Not[int](newTestPredicate(1, 2, 3)))) + s.False(empty.Equals(Universal[int]())) } diff --git a/common/predicates/not_test.go b/common/predicates/not_test.go index a0a131b89eb..ccd39786a23 100644 --- a/common/predicates/not_test.go +++ b/common/predicates/not_test.go @@ -3,24 +3,17 @@ package predicates import ( "testing" - "github.com/stretchr/testify/require" - "github.com/stretchr/testify/suite" + "go.temporal.io/server/common/testing/parallelsuite" ) type ( notSuite struct { - suite.Suite - *require.Assertions + parallelsuite.Suite[*notSuite] } ) func TestNotSuite(t *testing.T) { - s := new(notSuite) - suite.Run(t, s) -} - -func (s *notSuite) SetupTest() { - s.Assertions = require.New(s.T()) + parallelsuite.Run(t, new(notSuite)) } func (s *notSuite) TestNot_Test() { diff --git a/common/predicates/or_test.go b/common/predicates/or_test.go index a9a17a69d67..36d55dcfede 100644 --- a/common/predicates/or_test.go +++ b/common/predicates/or_test.go @@ -3,24 +3,17 @@ package predicates import ( "testing" - "github.com/stretchr/testify/require" - "github.com/stretchr/testify/suite" + "go.temporal.io/server/common/testing/parallelsuite" ) type ( orSuite struct { - suite.Suite - *require.Assertions + parallelsuite.Suite[*orSuite] } ) func TestOrSuite(t *testing.T) { - s := new(orSuite) - suite.Run(t, s) -} - -func (s *orSuite) SetupTest() { - s.Assertions = require.New(s.T()) + parallelsuite.Run(t, new(orSuite)) } func (s *orSuite) TestOr_Normal() { diff --git a/common/predicates/universal_test.go b/common/predicates/universal_test.go index 08cc9147908..b1d598d7af6 100644 --- a/common/predicates/universal_test.go +++ b/common/predicates/universal_test.go @@ -4,49 +4,40 @@ import ( "math/rand" "testing" - "github.com/stretchr/testify/require" - "github.com/stretchr/testify/suite" + "go.temporal.io/server/common/testing/parallelsuite" ) type ( universalSuite struct { - suite.Suite - *require.Assertions - - universal Predicate[int] + parallelsuite.Suite[*universalSuite] } ) func TestUniversalSuite(t *testing.T) { - s := new(universalSuite) - suite.Run(t, s) -} - -func (s *universalSuite) SetupTest() { - s.Assertions = require.New(s.T()) - - s.universal = Universal[int]() + parallelsuite.Run(t, new(universalSuite)) } func (s *universalSuite) TestUniversal_Test() { + universal := Universal[int]() for i := 0; i != 10; i++ { - s.True(s.universal.Test(rand.Int())) + s.True(universal.Test(rand.Int())) } } func (s *universalSuite) TestUniversal_Equals() { - s.True(s.universal.Equals(s.universal)) - s.True(s.universal.Equals(Universal[int]())) + universal := Universal[int]() + s.True(universal.Equals(universal)) + s.True(universal.Equals(Universal[int]())) - s.False(s.universal.Equals(newTestPredicate(1, 2, 3))) - s.False(s.universal.Equals(And[int]( + s.False(universal.Equals(newTestPredicate(1, 2, 3))) + s.False(universal.Equals(And[int]( newTestPredicate(1, 2, 3), newTestPredicate(2, 3, 4), ))) - s.False(s.universal.Equals(Or[int]( + s.False(universal.Equals(Or[int]( newTestPredicate(1, 2, 3), newTestPredicate(4, 5, 6), ))) - s.False(s.universal.Equals(Not[int](newTestPredicate(1, 2, 3)))) - s.False(s.universal.Equals(Empty[int]())) + s.False(universal.Equals(Not[int](newTestPredicate(1, 2, 3)))) + s.False(universal.Equals(Empty[int]())) } diff --git a/common/primitives/timestamp/parse_duration_test.go b/common/primitives/timestamp/parse_duration_test.go index 02d89ba54fb..5f8adb5d4e7 100644 --- a/common/primitives/timestamp/parse_duration_test.go +++ b/common/primitives/timestamp/parse_duration_test.go @@ -4,15 +4,15 @@ import ( "testing" "time" - "github.com/stretchr/testify/suite" + "go.temporal.io/server/common/testing/parallelsuite" ) type ParseDurationSuite struct { - suite.Suite + parallelsuite.Suite[*ParseDurationSuite] } func TestParseDurationSuite(t *testing.T) { - suite.Run(t, new(ParseDurationSuite)) + parallelsuite.Run(t, new(ParseDurationSuite)) } func (s *ParseDurationSuite) TestParseDuration() { diff --git a/service/history/api/resetworkflow/api_test.go b/service/history/api/resetworkflow/api_test.go index f751abe3438..59dd0bd3b11 100644 --- a/service/history/api/resetworkflow/api_test.go +++ b/service/history/api/resetworkflow/api_test.go @@ -3,19 +3,18 @@ package resetworkflow import ( "testing" - "github.com/stretchr/testify/suite" enumspb "go.temporal.io/api/enums/v1" + "go.temporal.io/server/common/testing/parallelsuite" ) type ( resetWorkflowSuite struct { - suite.Suite + parallelsuite.Suite[*resetWorkflowSuite] } ) func TestResetWorkflowSuite(t *testing.T) { - s := new(resetWorkflowSuite) - suite.Run(t, s) + parallelsuite.Run(t, new(resetWorkflowSuite)) } func (s *resetWorkflowSuite) TestGetResetReapplyExcludeTypes() { diff --git a/service/worker/batcher/activities_test.go b/service/worker/batcher/activities_test.go index b33e7407b17..e1a48a92165 100644 --- a/service/worker/batcher/activities_test.go +++ b/service/worker/batcher/activities_test.go @@ -11,7 +11,6 @@ import ( "unicode" "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/suite" batchpb "go.temporal.io/api/batch/v1" commonpb "go.temporal.io/api/common/v1" enumspb "go.temporal.io/api/enums/v1" @@ -31,30 +30,24 @@ import ( "go.temporal.io/server/common/primitives/timestamp" "go.temporal.io/server/common/quotas" "go.temporal.io/server/common/testing/mockapi/workflowservicemock/v1" + "go.temporal.io/server/common/testing/parallelsuite" "go.uber.org/mock/gomock" ) type activitiesSuite struct { - suite.Suite - testsuite.WorkflowTestSuite - - controller *gomock.Controller - - mockFrontendClient *workflowservicemock.MockWorkflowServiceClient + parallelsuite.Suite[*activitiesSuite] } -func (s *activitiesSuite) SetupTest() { - s.controller = gomock.NewController(s.T()) - - s.mockFrontendClient = workflowservicemock.NewMockWorkflowServiceClient(s.controller) +func TestActivitiesSuite(t *testing.T) { + parallelsuite.Run(t, new(activitiesSuite)) } -func TestActivitiesSuite(t *testing.T) { - suite.Run(t, new(activitiesSuite)) +func (s *activitiesSuite) newMockFrontendClient() *workflowservicemock.MockWorkflowServiceClient { + return workflowservicemock.NewMockWorkflowServiceClient(gomock.NewController(s.T())) } func (s *activitiesSuite) TestTaskTimeoutContext() { - s.Run("no parent deadline applies default timeout", func() { + s.Run("no parent deadline applies default timeout", func(s *activitiesSuite) { ctx, cancel := taskTimeoutContext(context.Background()) defer cancel() @@ -63,7 +56,7 @@ func (s *activitiesSuite) TestTaskTimeoutContext() { s.InDelta(defaultTaskTimeout, time.Until(deadline), float64(time.Second)) }) - s.Run("longer parent deadline is shortened to default timeout", func() { + s.Run("longer parent deadline is shortened to default timeout", func(s *activitiesSuite) { parent, parentCancel := context.WithTimeout(context.Background(), defaultTaskTimeout+time.Hour) defer parentCancel() @@ -75,7 +68,7 @@ func (s *activitiesSuite) TestTaskTimeoutContext() { s.InDelta(defaultTaskTimeout, time.Until(deadline), float64(time.Second)) }) - s.Run("shorter parent deadline is preserved", func() { + s.Run("shorter parent deadline is preserved", func(s *activitiesSuite) { shorter := defaultTaskTimeout - 5*time.Second parent, parentCancel := context.WithTimeout(context.Background(), shorter) defer parentCancel() @@ -156,18 +149,19 @@ func (s *activitiesSuite) TestGetLastWorkflowTaskEventID() { }, } for _, tt := range tests { - s.Run(tt.name, func() { + s.Run(tt.name, func(s *activitiesSuite) { + mockFrontendClient := s.newMockFrontendClient() ctx := context.Background() slices.Reverse(tt.history.Events) workflowExecution := &commonpb.WorkflowExecution{} - s.mockFrontendClient.EXPECT().GetWorkflowExecutionHistoryReverse(ctx, gomock.Any()).Return( + mockFrontendClient.EXPECT().GetWorkflowExecutionHistoryReverse(ctx, gomock.Any()).Return( &workflowservice.GetWorkflowExecutionHistoryReverseResponse{History: tt.history, NextPageToken: nil}, nil) - gotWorkflowTaskEventID, err := getLastWorkflowTaskEventID(ctx, namespaceStr, workflowExecution, s.mockFrontendClient, log.NewTestLogger()) + gotWorkflowTaskEventID, err := getLastWorkflowTaskEventID(ctx, namespaceStr, workflowExecution, mockFrontendClient, log.NewTestLogger()) s.Equal(tt.wantErr, err != nil) s.Equal(tt.wantWorkflowTaskEventID, gotWorkflowTaskEventID) if tt.wantErr { var appErr *temporal.ApplicationError - s.Require().ErrorAs(err, &appErr, "error should be an ApplicationError") + s.ErrorAs(err, &appErr, "error should be an ApplicationError") s.True(appErr.NonRetryable(), "error should be non-retryable") s.Equal("NoWorkflowTaskFound", appErr.Type(), "error type should be NoWorkflowTaskFound") } @@ -216,16 +210,17 @@ func (s *activitiesSuite) TestGetFirstWorkflowTaskEventID() { }, } for _, tt := range tests { - s.Run(tt.name, func() { + s.Run(tt.name, func(s *activitiesSuite) { + mockFrontendClient := s.newMockFrontendClient() ctx := context.Background() - s.mockFrontendClient.EXPECT().GetWorkflowExecutionHistory(ctx, gomock.Any()).Return( + mockFrontendClient.EXPECT().GetWorkflowExecutionHistory(ctx, gomock.Any()).Return( &workflowservice.GetWorkflowExecutionHistoryResponse{History: tt.history, NextPageToken: nil}, nil) - gotWorkflowTaskEventID, err := getFirstWorkflowTaskEventID(ctx, namespaceStr, &workflowExecution, s.mockFrontendClient, log.NewTestLogger()) + gotWorkflowTaskEventID, err := getFirstWorkflowTaskEventID(ctx, namespaceStr, &workflowExecution, mockFrontendClient, log.NewTestLogger()) s.Equal(tt.wantErr, err != nil) s.Equal(tt.wantWorkflowTaskEventID, gotWorkflowTaskEventID) if tt.wantErr { var appErr *temporal.ApplicationError - s.Require().ErrorAs(err, &appErr, "error should be an ApplicationError") + s.ErrorAs(err, &appErr, "error should be an ApplicationError") s.True(appErr.NonRetryable(), "error should be non-retryable") s.Equal("NoWorkflowTaskFound", appErr.Type(), "error type should be NoWorkflowTaskFound") } @@ -328,8 +323,9 @@ func (s *activitiesSuite) TestGetResetPoint() { }, } for _, tt := range tests { - s.Run(tt.name, func() { - s.mockFrontendClient.EXPECT().DescribeWorkflowExecution(gomock.Any(), gomock.Any()).Return( + s.Run(tt.name, func(s *activitiesSuite) { + mockFrontendClient := s.newMockFrontendClient() + mockFrontendClient.EXPECT().DescribeWorkflowExecution(gomock.Any(), gomock.Any()).Return( &workflowservice.DescribeWorkflowExecutionResponse{ WorkflowExecutionInfo: &workflowpb.WorkflowExecutionInfo{ AutoResetPoints: &workflowpb.ResetPoints{ @@ -343,7 +339,7 @@ func (s *activitiesSuite) TestGetResetPoint() { WorkflowId: "wfid", RunId: "run1", } - id, err := getResetPoint(ctx, ns, execution, s.mockFrontendClient, tt.buildId, tt.currentRunOnly) + id, err := getResetPoint(ctx, ns, execution, mockFrontendClient, tt.buildId, tt.currentRunOnly) s.Equal(tt.wantErr, err != nil) s.Equal(tt.wantWorkflowTaskEventID, id) if tt.wantSetRunId != "" { @@ -404,7 +400,7 @@ func (s *activitiesSuite) TestAdjustQueryBatchTypeEnum() { }, } for _, testRun := range tests { - s.Run(testRun.name, func() { + s.Run(testRun.name, func(s *activitiesSuite) { a := activities{} adjustedQuery := a.adjustQueryBatchTypeEnum(testRun.query, testRun.batchType) s.Equal(testRun.expectedResult, adjustedQuery) @@ -415,7 +411,7 @@ func (s *activitiesSuite) TestAdjustQueryBatchTypeEnum() { func (s *activitiesSuite) TestAdjustQueryAdminBatchType() { a := activities{} - s.Run("Empty query", func() { + s.Run("Empty query", func(s *activitiesSuite) { adminReq := &adminservice.StartAdminBatchOperationRequest{ VisibilityQuery: "", Operation: &adminservice.StartAdminBatchOperationRequest_RefreshTasksOperation{ @@ -426,7 +422,7 @@ func (s *activitiesSuite) TestAdjustQueryAdminBatchType() { s.Empty(adjustedQuery) }) - s.Run("RefreshWorkflowTasks returns query unchanged", func() { + s.Run("RefreshWorkflowTasks returns query unchanged", func(s *activitiesSuite) { adminReq := &adminservice.StartAdminBatchOperationRequest{ VisibilityQuery: "WorkflowType='MyWorkflow'", Identity: "test", @@ -439,7 +435,7 @@ func (s *activitiesSuite) TestAdjustQueryAdminBatchType() { s.Equal("WorkflowType='MyWorkflow'", adjustedQuery) }) - s.Run("RefreshWorkflowTasks with complex query unchanged", func() { + s.Run("RefreshWorkflowTasks with complex query unchanged", func(s *activitiesSuite) { adminReq := &adminservice.StartAdminBatchOperationRequest{ VisibilityQuery: "(WorkflowType='MyWorkflow') OR (WorkflowType='OtherWorkflow')", Operation: &adminservice.StartAdminBatchOperationRequest_RefreshTasksOperation{ @@ -451,7 +447,7 @@ func (s *activitiesSuite) TestAdjustQueryAdminBatchType() { s.Equal("(WorkflowType='MyWorkflow') OR (WorkflowType='OtherWorkflow')", adjustedQuery) }) - s.Run("Nil operation returns query unchanged", func() { + s.Run("Nil operation returns query unchanged", func(s *activitiesSuite) { adminReq := &adminservice.StartAdminBatchOperationRequest{ VisibilityQuery: "WorkflowType='MyWorkflow'", } @@ -462,7 +458,7 @@ func (s *activitiesSuite) TestAdjustQueryAdminBatchType() { func (s *activitiesSuite) TestProcessAdminTask_RefreshWorkflowTasks() { ctx := context.Background() - mockHistoryClient := historyservicemock.NewMockHistoryServiceClient(s.controller) + mockHistoryClient := historyservicemock.NewMockHistoryServiceClient(gomock.NewController(s.T())) a := &activities{ activityDeps: activityDeps{ @@ -512,7 +508,7 @@ func (s *activitiesSuite) TestProcessAdminTask_RefreshWorkflowTasks() { func (s *activitiesSuite) TestProcessAdminTask_RefreshWorkflowTasks_Error() { ctx := context.Background() - mockHistoryClient := historyservicemock.NewMockHistoryServiceClient(s.controller) + mockHistoryClient := historyservicemock.NewMockHistoryServiceClient(gomock.NewController(s.T())) a := &activities{ activityDeps: activityDeps{ @@ -547,7 +543,7 @@ func (s *activitiesSuite) TestProcessAdminTask_RefreshWorkflowTasks_Error() { mockHistoryClient.EXPECT().RefreshWorkflowTasks(gomock.Any(), gomock.Any()).Return(nil, expectedErr) err := a.processAdminTask(ctx, batchOperation, testTask, limiter) - s.Require().Error(err) + s.Error(err) s.Equal(expectedErr, err) } @@ -603,7 +599,7 @@ func (s *activitiesSuite) TestIsNonRetryableError() { } for _, tt := range tests { - s.Run(tt.name, func() { + s.Run(tt.name, func(s *activitiesSuite) { got := isNonRetryableError(tt.err, tt.batchType) s.Equal(tt.want, got) }) @@ -619,9 +615,10 @@ func (s *activitiesSuite) TestStartTaskProcessor_SignalUsesWorkerNamespace() { ctx, cancel := context.WithCancel(context.Background()) defer cancel() + mockFrontendClient := s.newMockFrontendClient() a := &activities{ activityDeps: activityDeps{ - FrontendClient: s.mockFrontendClient, + FrontendClient: mockFrontendClient, Logger: log.NewTestLogger(), MetricsHandler: metrics.NoopMetricsHandler, }, @@ -663,7 +660,7 @@ func (s *activitiesSuite) TestStartTaskProcessor_SignalUsesWorkerNamespace() { limiter := quotas.NewRequestRateLimiterAdapter(quotas.NewDefaultOutgoingRateLimiter(func() float64 { return 100 })) // The signal must be executed with the worker's trusted namespace, not the user-supplied one. - s.mockFrontendClient.EXPECT(). + mockFrontendClient.EXPECT(). SignalWorkflowExecution(gomock.Any(), gomock.Any()). DoAndReturn(func(_ context.Context, req *workflowservice.SignalWorkflowExecutionRequest, _ ...any) (*workflowservice.SignalWorkflowExecutionResponse, error) { s.Equal(workerNamespace, req.Namespace, "must use worker namespace, not request namespace") @@ -672,7 +669,7 @@ func (s *activitiesSuite) TestStartTaskProcessor_SignalUsesWorkerNamespace() { taskCh <- testTask - go a.startTaskProcessor(ctx, batchOperation, workerNamespace, taskCh, respCh, limiter, nil, s.mockFrontendClient, metrics.NoopMetricsHandler, log.NewTestLogger()) + go a.startTaskProcessor(ctx, batchOperation, workerNamespace, taskCh, respCh, limiter, nil, mockFrontendClient, metrics.NoopMetricsHandler, log.NewTestLogger()) resp := <-respCh s.NoError(resp.err) @@ -685,9 +682,10 @@ func (s *activitiesSuite) TestStartTaskProcessor_RetryableErrorsDoNotDeadlock() ctx, cancel := context.WithCancel(context.Background()) defer cancel() + mockFrontendClient := s.newMockFrontendClient() a := &activities{ activityDeps: activityDeps{ - FrontendClient: s.mockFrontendClient, + FrontendClient: mockFrontendClient, Logger: log.NewTestLogger(), MetricsHandler: metrics.NoopMetricsHandler, }, @@ -707,7 +705,7 @@ func (s *activitiesSuite) TestStartTaskProcessor_RetryableErrorsDoNotDeadlock() } // Every signal fails with a retryable error, forcing the worker down the retry path. - s.mockFrontendClient.EXPECT(). + mockFrontendClient.EXPECT(). SignalWorkflowExecution(gomock.Any(), gomock.Any()). Return(nil, errors.New("transient error")). AnyTimes() @@ -718,7 +716,7 @@ func (s *activitiesSuite) TestStartTaskProcessor_RetryableErrorsDoNotDeadlock() respCh := make(chan taskResponse, 1) limiter := quotas.NewRequestRateLimiterAdapter(quotas.NewDefaultOutgoingRateLimiter(func() float64 { return 1000 })) - go a.startTaskProcessor(ctx, batchOperation, "ns", taskCh, respCh, limiter, nil, s.mockFrontendClient, metrics.NoopMetricsHandler, log.NewTestLogger()) + go a.startTaskProcessor(ctx, batchOperation, "ns", taskCh, respCh, limiter, nil, mockFrontendClient, metrics.NoopMetricsHandler, log.NewTestLogger()) // Feed tasks from a separate goroutine so the test can drain responses concurrently. go func() { @@ -736,7 +734,7 @@ func (s *activitiesSuite) TestStartTaskProcessor_RetryableErrorsDoNotDeadlock() for range numTasks { select { case resp := <-respCh: - s.Require().Error(resp.err) + s.Error(resp.err) case <-time.After(10 * time.Second): s.FailNow("timed out waiting for task response: worker is deadlocked") } @@ -858,7 +856,8 @@ func (s *activitiesSuite) TestProcessWorkflowsWithProactiveFetching_ProcessesAll limiter := quotas.NewRequestRateLimiterAdapter(quotas.NewDefaultOutgoingRateLimiter(func() float64 { return 10000 })) // Run inside an activity environment so the coordinator's RecordHeartbeat call is valid. - env := s.NewTestActivityEnvironment() + testSuite := &testsuite.WorkflowTestSuite{} + env := testSuite.NewTestActivityEnvironment() runner := func(ctx context.Context) (HeartBeatDetails, error) { return a.processWorkflowsWithProactiveFetching( ctx, config, fakeWorker, limiter, mockSdk, metrics.NoopMetricsHandler, log.NewTestLogger(), HeartBeatDetails{}, @@ -902,6 +901,6 @@ func (s *activitiesSuite) TestProcessAdminTask_UnknownOperation() { limiter := quotas.NewRequestRateLimiterAdapter(quotas.NewDefaultOutgoingRateLimiter(func() float64 { return 100 })) err := a.processAdminTask(ctx, batchOperation, testTask, limiter) - s.Require().Error(err) + s.Error(err) s.Contains(err.Error(), "unknown admin batch type") } diff --git a/service/worker/batcher/workflow_test.go b/service/worker/batcher/workflow_test.go index 1104d28224b..cd690fdb6e3 100644 --- a/service/worker/batcher/workflow_test.go +++ b/service/worker/batcher/workflow_test.go @@ -5,47 +5,42 @@ import ( "github.com/google/uuid" "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/suite" batchpb "go.temporal.io/api/batch/v1" commonpb "go.temporal.io/api/common/v1" enumspb "go.temporal.io/api/enums/v1" "go.temporal.io/api/workflowservice/v1" "go.temporal.io/sdk/testsuite" batchspb "go.temporal.io/server/api/batch/v1" - "go.uber.org/mock/gomock" + "go.temporal.io/server/common/testing/parallelsuite" ) type batcherSuite struct { - suite.Suite - testsuite.WorkflowTestSuite - controller *gomock.Controller - env *testsuite.TestWorkflowEnvironment + parallelsuite.Suite[*batcherSuite] } func TestBatcherSuite(t *testing.T) { - suite.Run(t, new(batcherSuite)) + parallelsuite.Run(t, new(batcherSuite)) } -func (s *batcherSuite) SetupTest() { - s.controller = gomock.NewController(s.T()) - s.env = s.WorkflowTestSuite.NewTestWorkflowEnvironment() - s.env.RegisterWorkflow(BatchWorkflowProtobuf) -} - -func (s *batcherSuite) TearDownTest() { - s.controller.Finish() - s.env.AssertExpectations(s.T()) +func (s *batcherSuite) newTestWorkflowEnvironment() *testsuite.TestWorkflowEnvironment { + testSuite := &testsuite.WorkflowTestSuite{} + env := testSuite.NewTestWorkflowEnvironment() + env.RegisterWorkflow(BatchWorkflowProtobuf) + return env } func (s *batcherSuite) TestBatchWorkflow_ValidParams_Query_Protobuf() { + env := s.newTestWorkflowEnvironment() + defer env.AssertExpectations(s.T()) + var ac *activities - s.env.OnActivity(ac.BatchActivityWithProtobuf, mock.Anything, mock.Anything).Return(HeartBeatDetails{ + env.OnActivity(ac.BatchActivityWithProtobuf, mock.Anything, mock.Anything).Return(HeartBeatDetails{ SuccessCount: 42, ErrorCount: 27, }, nil) - s.env.OnUpsertMemo(mock.Anything).Run(func(args mock.Arguments) { + env.OnUpsertMemo(mock.Anything).Run(func(args mock.Arguments) { memo, ok := args.Get(0).(map[string]any) - s.Require().True(ok) + s.True(ok) s.Equal(map[string]any{ "batch_operation_stats": BatchOperationStats{ NumSuccess: 42, @@ -53,7 +48,7 @@ func (s *batcherSuite) TestBatchWorkflow_ValidParams_Query_Protobuf() { }, }, memo) }).Once() - s.env.ExecuteWorkflow(BatchWorkflowProtobuf, &batchspb.BatchOperationInput{ + env.ExecuteWorkflow(BatchWorkflowProtobuf, &batchspb.BatchOperationInput{ Request: &workflowservice.StartBatchOperationRequest{ JobId: uuid.NewString(), Operation: &workflowservice.StartBatchOperationRequest_TerminationOperation{ @@ -65,19 +60,22 @@ func (s *batcherSuite) TestBatchWorkflow_ValidParams_Query_Protobuf() { }, BatchType: enumspb.BATCH_OPERATION_TYPE_TERMINATE, }) - err := s.env.GetWorkflowError() - s.Require().NoError(err) + err := env.GetWorkflowError() + s.NoError(err) } func (s *batcherSuite) TestBatchWorkflow_ValidParams_Executions_Protobuf() { + env := s.newTestWorkflowEnvironment() + defer env.AssertExpectations(s.T()) + var ac *activities - s.env.OnActivity(ac.BatchActivityWithProtobuf, mock.Anything, mock.Anything).Return(HeartBeatDetails{ + env.OnActivity(ac.BatchActivityWithProtobuf, mock.Anything, mock.Anything).Return(HeartBeatDetails{ SuccessCount: 42, ErrorCount: 27, }, nil) - s.env.OnUpsertMemo(mock.Anything).Run(func(args mock.Arguments) { + env.OnUpsertMemo(mock.Anything).Run(func(args mock.Arguments) { memo, ok := args.Get(0).(map[string]any) - s.Require().True(ok) + s.True(ok) s.Equal(map[string]any{ "batch_operation_stats": BatchOperationStats{ NumSuccess: 42, @@ -85,7 +83,7 @@ func (s *batcherSuite) TestBatchWorkflow_ValidParams_Executions_Protobuf() { }, }, memo) }).Once() - s.env.ExecuteWorkflow(BatchWorkflowProtobuf, &batchspb.BatchOperationInput{ + env.ExecuteWorkflow(BatchWorkflowProtobuf, &batchspb.BatchOperationInput{ Request: &workflowservice.StartBatchOperationRequest{ JobId: uuid.NewString(), Operation: &workflowservice.StartBatchOperationRequest_TerminationOperation{ @@ -102,6 +100,6 @@ func (s *batcherSuite) TestBatchWorkflow_ValidParams_Executions_Protobuf() { }, BatchType: enumspb.BATCH_OPERATION_TYPE_TERMINATE, }) - err := s.env.GetWorkflowError() - s.Require().NoError(err) + err := env.GetWorkflowError() + s.NoError(err) } diff --git a/service/worker/migration/force_replication_workflow_test.go b/service/worker/migration/force_replication_workflow_test.go index 8cb9e9e9188..96d1581402d 100644 --- a/service/worker/migration/force_replication_workflow_test.go +++ b/service/worker/migration/force_replication_workflow_test.go @@ -14,7 +14,6 @@ import ( "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/suite" namespacepb "go.temporal.io/api/namespace/v1" "go.temporal.io/api/workflowservice/v1" "go.temporal.io/sdk/interceptor" @@ -26,13 +25,13 @@ import ( "go.temporal.io/server/common/log" "go.temporal.io/server/common/persistence" "go.temporal.io/server/common/testing/mockapi/workflowservicemock/v1" + "go.temporal.io/server/common/testing/parallelsuite" "go.uber.org/mock/gomock" ) type ( ForceReplicationWorkflowTestSuite struct { - suite.Suite - forceReplicationWorkflowFn any + parallelsuite.Suite[*ForceReplicationWorkflowTestSuite] } ) @@ -52,15 +51,12 @@ func TestForceReplicationWorkflowTestSuite(t *testing.T) { }, } { t.Run(tc.name, func(t *testing.T) { - s := &ForceReplicationWorkflowTestSuite{ - forceReplicationWorkflowFn: tc.forceReplicationWorkflowFn, - } - suite.Run(t, s) + parallelsuite.Run(t, &ForceReplicationWorkflowTestSuite{}, tc.forceReplicationWorkflowFn) }) } } -func (s *ForceReplicationWorkflowTestSuite) TestForceReplicationWorkflow() { +func (s *ForceReplicationWorkflowTestSuite) TestForceReplicationWorkflow(forceReplicationWorkflowFn any) { testSuite := &testsuite.WorkflowTestSuite{} env := testSuite.NewTestWorkflowEnvironment() env.RegisterWorkflowWithOptions(ForceTaskQueueUserDataReplicationWorkflow, workflow.RegisterOptions{Name: forceTaskQueueUserDataReplicationWorkflow}) @@ -100,7 +96,7 @@ func (s *ForceReplicationWorkflowTestSuite) TestForceReplicationWorkflow() { env.OnActivity(a.VerifyReplicationTasks, mock.Anything, mock.Anything).Return(verifyReplicationTasksResponse{VerifiedWorkflowCount: 1}, nil).Times(totalPageCount) env.OnActivity(a.SeedReplicationQueueWithUserDataEntries, mock.Anything, mock.Anything).Return(nil).Times(1) - env.ExecuteWorkflow(s.forceReplicationWorkflowFn, ForceReplicationParams{ + env.ExecuteWorkflow(forceReplicationWorkflowFn, ForceReplicationParams{ Namespace: "test-ns", Query: "", ConcurrentActivityCount: 100, @@ -131,7 +127,7 @@ func (s *ForceReplicationWorkflowTestSuite) TestForceReplicationWorkflow() { s.Equal([]byte(nil), status.PageTokenForRestart) } -func (s *ForceReplicationWorkflowTestSuite) TestContinueAsNew() { +func (s *ForceReplicationWorkflowTestSuite) TestContinueAsNew(forceReplicationWorkflowFn any) { totalPageCount := 4 currentPageCount := 0 testMaxPageCountPerExecution := 2 @@ -185,6 +181,7 @@ func (s *ForceReplicationWorkflowTestSuite) TestContinueAsNew() { // Run the workflow once. We should get a continue as new error. continueAsNewInput, queryStatus := s.testRunForceReplicationForContinueAsNew( + forceReplicationWorkflowFn, mockListWorkflows, ForceReplicationParams{ Namespace: "test-ns", @@ -235,6 +232,7 @@ func (s *ForceReplicationWorkflowTestSuite) TestContinueAsNew() { } func (s *ForceReplicationWorkflowTestSuite) testRunForceReplicationForContinueAsNew( + forceReplicationWorkflowFn any, mockListWorkflows func(context.Context, *workflowservice.ListWorkflowExecutionsRequest) (*listWorkflowsResponse, error), input ForceReplicationParams, expectContinueAsNew bool, @@ -257,7 +255,7 @@ func (s *ForceReplicationWorkflowTestSuite) testRunForceReplicationForContinueAs // executions of ForceReplication. The SeedReplicationQueueWithUserDataEntries activity will eventually run // once, but we aren't guaranteed that it will run during any given execution of ForceReplication. env.OnActivity(a.SeedReplicationQueueWithUserDataEntries, mock.Anything, mock.Anything).Return(nil).Maybe() - env.ExecuteWorkflow(s.forceReplicationWorkflowFn, input) + env.ExecuteWorkflow(forceReplicationWorkflowFn, input) s.True(env.IsWorkflowCompleted()) err := env.GetWorkflowError() @@ -287,7 +285,7 @@ func (s *ForceReplicationWorkflowTestSuite) testRunForceReplicationForContinueAs return continueAsNewParams, status } -func (s *ForceReplicationWorkflowTestSuite) TestInvalidInput() { +func (s *ForceReplicationWorkflowTestSuite) TestInvalidInput(forceReplicationWorkflowFn any) { testSuite := &testsuite.WorkflowTestSuite{} for _, invalidInput := range []ForceReplicationParams{ @@ -301,7 +299,7 @@ func (s *ForceReplicationWorkflowTestSuite) TestInvalidInput() { }, } { env := testSuite.NewTestWorkflowEnvironment() - env.ExecuteWorkflow(s.forceReplicationWorkflowFn, invalidInput) + env.ExecuteWorkflow(forceReplicationWorkflowFn, invalidInput) s.True(env.IsWorkflowCompleted()) err := env.GetWorkflowError() @@ -312,7 +310,7 @@ func (s *ForceReplicationWorkflowTestSuite) TestInvalidInput() { } } -func (s *ForceReplicationWorkflowTestSuite) TestListWorkflowsError() { +func (s *ForceReplicationWorkflowTestSuite) TestListWorkflowsError(forceReplicationWorkflowFn any) { testSuite := &testsuite.WorkflowTestSuite{} env := testSuite.NewTestWorkflowEnvironment() env.RegisterWorkflowWithOptions(ForceTaskQueueUserDataReplicationWorkflow, workflow.RegisterOptions{Name: forceTaskQueueUserDataReplicationWorkflow}) @@ -327,7 +325,7 @@ func (s *ForceReplicationWorkflowTestSuite) TestListWorkflowsError() { env.OnActivity(a.SeedReplicationQueueWithUserDataEntries, mock.Anything, mock.Anything).Return(nil) - env.ExecuteWorkflow(s.forceReplicationWorkflowFn, ForceReplicationParams{ + env.ExecuteWorkflow(forceReplicationWorkflowFn, ForceReplicationParams{ Namespace: "test-ns", Query: "", ConcurrentActivityCount: 2, @@ -343,7 +341,7 @@ func (s *ForceReplicationWorkflowTestSuite) TestListWorkflowsError() { env.AssertExpectations(s.T()) } -func (s *ForceReplicationWorkflowTestSuite) TestGenerateReplicationTaskRetryableError() { +func (s *ForceReplicationWorkflowTestSuite) TestGenerateReplicationTaskRetryableError(forceReplicationWorkflowFn any) { testSuite := &testsuite.WorkflowTestSuite{} env := testSuite.NewTestWorkflowEnvironment() env.RegisterWorkflowWithOptions(ForceTaskQueueUserDataReplicationWorkflow, workflow.RegisterOptions{Name: forceTaskQueueUserDataReplicationWorkflow}) @@ -376,7 +374,7 @@ func (s *ForceReplicationWorkflowTestSuite) TestGenerateReplicationTaskRetryable env.OnActivity(a.SeedReplicationQueueWithUserDataEntries, mock.Anything, mock.Anything).Return(nil) - env.ExecuteWorkflow(s.forceReplicationWorkflowFn, ForceReplicationParams{ + env.ExecuteWorkflow(forceReplicationWorkflowFn, ForceReplicationParams{ Namespace: "test-ns", Query: "", ConcurrentActivityCount: 2, @@ -392,7 +390,7 @@ func (s *ForceReplicationWorkflowTestSuite) TestGenerateReplicationTaskRetryable env.AssertExpectations(s.T()) } -func (s *ForceReplicationWorkflowTestSuite) TestGenerateReplicationTaskNonRetryableError() { +func (s *ForceReplicationWorkflowTestSuite) TestGenerateReplicationTaskNonRetryableError(forceReplicationWorkflowFn any) { testSuite := &testsuite.WorkflowTestSuite{} env := testSuite.NewTestWorkflowEnvironment() env.RegisterWorkflowWithOptions(ForceTaskQueueUserDataReplicationWorkflow, workflow.RegisterOptions{Name: forceTaskQueueUserDataReplicationWorkflow}) @@ -430,7 +428,7 @@ func (s *ForceReplicationWorkflowTestSuite) TestGenerateReplicationTaskNonRetrya env.OnActivity(a.SeedReplicationQueueWithUserDataEntries, mock.Anything, mock.Anything).Return(nil) - env.ExecuteWorkflow(s.forceReplicationWorkflowFn, ForceReplicationParams{ + env.ExecuteWorkflow(forceReplicationWorkflowFn, ForceReplicationParams{ Namespace: "test-ns", Query: "", ConcurrentActivityCount: 1, @@ -448,7 +446,7 @@ func (s *ForceReplicationWorkflowTestSuite) TestGenerateReplicationTaskNonRetrya env.AssertExpectations(s.T()) } -func (s *ForceReplicationWorkflowTestSuite) TestVerifyReplicationTaskNonRetryableError() { +func (s *ForceReplicationWorkflowTestSuite) TestVerifyReplicationTaskNonRetryableError(forceReplicationWorkflowFn any) { testSuite := &testsuite.WorkflowTestSuite{} env := testSuite.NewTestWorkflowEnvironment() env.RegisterWorkflowWithOptions(ForceTaskQueueUserDataReplicationWorkflow, workflow.RegisterOptions{Name: forceTaskQueueUserDataReplicationWorkflow}) @@ -487,7 +485,7 @@ func (s *ForceReplicationWorkflowTestSuite) TestVerifyReplicationTaskNonRetryabl env.OnActivity(a.SeedReplicationQueueWithUserDataEntries, mock.Anything, mock.Anything).Return(nil) - env.ExecuteWorkflow(s.forceReplicationWorkflowFn, ForceReplicationParams{ + env.ExecuteWorkflow(forceReplicationWorkflowFn, ForceReplicationParams{ Namespace: "test-ns", Query: "", ConcurrentActivityCount: 1, @@ -505,7 +503,7 @@ func (s *ForceReplicationWorkflowTestSuite) TestVerifyReplicationTaskNonRetryabl env.AssertExpectations(s.T()) } -func (s *ForceReplicationWorkflowTestSuite) TestTaskQueueReplicationFailure() { +func (s *ForceReplicationWorkflowTestSuite) TestTaskQueueReplicationFailure(forceReplicationWorkflowFn any) { testSuite := &testsuite.WorkflowTestSuite{} env := testSuite.NewTestWorkflowEnvironment() env.RegisterWorkflowWithOptions(ForceTaskQueueUserDataReplicationWorkflow, workflow.RegisterOptions{Name: forceTaskQueueUserDataReplicationWorkflow}) @@ -524,7 +522,7 @@ func (s *ForceReplicationWorkflowTestSuite) TestTaskQueueReplicationFailure() { temporal.NewNonRetryableApplicationError("namespace is required", "InvalidArgument", nil), ) - env.ExecuteWorkflow(s.forceReplicationWorkflowFn, ForceReplicationParams{ + env.ExecuteWorkflow(forceReplicationWorkflowFn, ForceReplicationParams{ Namespace: "test-ns", Query: "", ConcurrentActivityCount: 2, @@ -548,7 +546,7 @@ func (s *ForceReplicationWorkflowTestSuite) TestTaskQueueReplicationFailure() { s.Equal([]byte(nil), status.PageTokenForRestart) } -func (s *ForceReplicationWorkflowTestSuite) TestVerifyPerIterationExecutions() { +func (s *ForceReplicationWorkflowTestSuite) TestVerifyPerIterationExecutions(forceReplicationWorkflowFn any) { testSuite := &testsuite.WorkflowTestSuite{} env := testSuite.NewTestWorkflowEnvironment() env.RegisterWorkflowWithOptions(ForceTaskQueueUserDataReplicationWorkflow, workflow.RegisterOptions{Name: forceTaskQueueUserDataReplicationWorkflow}) @@ -606,7 +604,7 @@ func (s *ForceReplicationWorkflowTestSuite) TestVerifyPerIterationExecutions() { // Seed task queue replication activity may or may not run in this execution; allow either. env.OnActivity(a.SeedReplicationQueueWithUserDataEntries, mock.Anything, mock.Anything).Return(nil).Maybe() - env.ExecuteWorkflow(s.forceReplicationWorkflowFn, ForceReplicationParams{ + env.ExecuteWorkflow(forceReplicationWorkflowFn, ForceReplicationParams{ Namespace: "test-ns", Query: "", ConcurrentActivityCount: 2, diff --git a/service/worker/scanner/executor/executor_test.go b/service/worker/scanner/executor/executor_test.go index 27bee1bf52d..87b8b9bd2ec 100644 --- a/service/worker/scanner/executor/executor_test.go +++ b/service/worker/scanner/executor/executor_test.go @@ -6,14 +6,14 @@ import ( "testing" "time" - "github.com/stretchr/testify/suite" "go.temporal.io/server/common" "go.temporal.io/server/common/metrics" + "go.temporal.io/server/common/testing/parallelsuite" ) type ( ExecutorTestSuite struct { - suite.Suite + parallelsuite.Suite[*ExecutorTestSuite] } testTask struct { next TaskStatus @@ -22,7 +22,7 @@ type ( ) func TestExecutionTestSuite(t *testing.T) { - suite.Run(t, new(ExecutorTestSuite)) + parallelsuite.Run(t, new(ExecutorTestSuite)) } func (s *ExecutorTestSuite) TestStartStop() { diff --git a/service/worker/scanner/scanner_test.go b/service/worker/scanner/scanner_test.go index c6a2eac176f..0cf46411b3c 100644 --- a/service/worker/scanner/scanner_test.go +++ b/service/worker/scanner/scanner_test.go @@ -5,7 +5,6 @@ import ( "sync" "testing" - "github.com/stretchr/testify/suite" "go.temporal.io/sdk/client" "go.temporal.io/server/api/adminservicemock/v1" "go.temporal.io/server/api/historyservicemock/v1" @@ -19,16 +18,17 @@ import ( "go.temporal.io/server/common/persistence/serialization" "go.temporal.io/server/common/sdk" "go.temporal.io/server/common/testing/mocksdk" + "go.temporal.io/server/common/testing/parallelsuite" "go.temporal.io/server/service/worker/scanner/build_ids" "go.uber.org/mock/gomock" ) type scannerTestSuite struct { - suite.Suite + parallelsuite.Suite[*scannerTestSuite] } func TestScanner(t *testing.T) { - suite.Run(t, new(scannerTestSuite)) + parallelsuite.Run(t, new(scannerTestSuite)) } func (s *scannerTestSuite) TestScannerEnabled() { @@ -174,7 +174,7 @@ func (s *scannerTestSuite) TestScannerEnabled() { ExpectedScanners: []expectedScanner{historyScanner, executionScanner, buildIdScavenger}, // TaskQueueScanner is only supported for SQL store }, } { - s.Run(c.Name, func() { + s.Run(c.Name, func(s *scannerTestSuite) { ctrl := gomock.NewController(s.T()) mockSdkClientFactory := sdk.NewMockClientFactory(ctrl) mockSdkClient := mocksdk.NewMockClient(ctrl) diff --git a/service/worker/scanner/workflow_test.go b/service/worker/scanner/workflow_test.go index 1bf24060ce0..f56e76cdc3d 100644 --- a/service/worker/scanner/workflow_test.go +++ b/service/worker/scanner/workflow_test.go @@ -6,7 +6,6 @@ import ( "time" "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/suite" "go.temporal.io/sdk/activity" "go.temporal.io/sdk/testsuite" "go.temporal.io/sdk/worker" @@ -14,16 +13,16 @@ import ( p "go.temporal.io/server/common/persistence" "go.temporal.io/server/common/primitives" "go.temporal.io/server/common/resourcetest" + "go.temporal.io/server/common/testing/parallelsuite" "go.uber.org/mock/gomock" ) type scannerWorkflowTestSuite struct { - suite.Suite - testsuite.WorkflowTestSuite + parallelsuite.Suite[*scannerWorkflowTestSuite] } func TestScannerWorkflowTestSuite(t *testing.T) { - suite.Run(t, new(scannerWorkflowTestSuite)) + parallelsuite.RunLegacySequential(t, new(scannerWorkflowTestSuite)) //nolint:staticcheck // SA1019: suite mutates scanner heartbeat interval package global. } func (s *scannerWorkflowTestSuite) registerWorkflows(env *testsuite.TestWorkflowEnvironment) { @@ -39,7 +38,8 @@ func (s *scannerWorkflowTestSuite) registerActivities(env *testsuite.TestActivit } func (s *scannerWorkflowTestSuite) TestWorkflow() { - env := s.NewTestWorkflowEnvironment() + testSuite := &testsuite.WorkflowTestSuite{} + env := testSuite.NewTestWorkflowEnvironment() s.registerWorkflows(env) env.OnActivity(taskQueueScavengerActivityName, mock.Anything).Return(nil) env.ExecuteWorkflow(tqScannerWFTypeName) @@ -47,7 +47,8 @@ func (s *scannerWorkflowTestSuite) TestWorkflow() { } func (s *scannerWorkflowTestSuite) TestScavengerActivity() { - env := s.NewTestActivityEnvironment() + testSuite := &testsuite.WorkflowTestSuite{} + env := testSuite.NewTestActivityEnvironment() s.registerActivities(env) controller := gomock.NewController(s.T()) defer controller.Finish() diff --git a/service/worker/scheduler/buffer_test.go b/service/worker/scheduler/buffer_test.go index af48ac60624..5e202ed675c 100644 --- a/service/worker/scheduler/buffer_test.go +++ b/service/worker/scheduler/buffer_test.go @@ -3,13 +3,13 @@ package scheduler import ( "testing" - "github.com/stretchr/testify/suite" enumspb "go.temporal.io/api/enums/v1" + "go.temporal.io/server/common/testing/parallelsuite" ) type ( processBufferSuite struct { - suite.Suite + parallelsuite.Suite[*processBufferSuite] } job struct { @@ -30,7 +30,7 @@ func jobIds(jobs []*job) (out []int) { func identity[T any](v T) T { return v } func TestProcessBuffer(t *testing.T) { - suite.Run(t, new(processBufferSuite)) + parallelsuite.Run(t, new(processBufferSuite)) } func (s *processBufferSuite) TestProcessBufferEmpty() { diff --git a/service/worker/scheduler/calendar_test.go b/service/worker/scheduler/calendar_test.go index 1ece799bf1e..36bdca30e30 100644 --- a/service/worker/scheduler/calendar_test.go +++ b/service/worker/scheduler/calendar_test.go @@ -5,23 +5,17 @@ import ( "testing" "time" - "github.com/stretchr/testify/require" - "github.com/stretchr/testify/suite" schedulepb "go.temporal.io/api/schedule/v1" + "go.temporal.io/server/common/testing/parallelsuite" "google.golang.org/protobuf/types/known/durationpb" ) type calendarSuite struct { - suite.Suite - *require.Assertions + parallelsuite.Suite[*calendarSuite] } func TestCalendar(t *testing.T) { - suite.Run(t, new(calendarSuite)) -} - -func (s *calendarSuite) SetupTest() { - s.Assertions = require.New(s.T()) + parallelsuite.Run(t, new(calendarSuite)) } func (s *calendarSuite) mustCompileCalendarSpec(cal *schedulepb.CalendarSpec, tz *time.Location) *compiledCalendar { diff --git a/service/worker/scheduler/spec_test.go b/service/worker/scheduler/spec_test.go index 2a8a3b338cc..d9d6cec5c19 100644 --- a/service/worker/scheduler/spec_test.go +++ b/service/worker/scheduler/spec_test.go @@ -4,32 +4,23 @@ import ( "testing" "time" - "github.com/stretchr/testify/suite" schedulepb "go.temporal.io/api/schedule/v1" - "go.temporal.io/server/common/testing/protorequire" + "go.temporal.io/server/common/testing/parallelsuite" "google.golang.org/protobuf/types/known/durationpb" "google.golang.org/protobuf/types/known/timestamppb" ) type specSuite struct { - suite.Suite - protorequire.ProtoAssertions - - specBuilder *SpecBuilder + parallelsuite.Suite[*specSuite] } func TestSpec(t *testing.T) { - suite.Run(t, new(specSuite)) -} - -func (s *specSuite) SetupTest() { - s.ProtoAssertions = protorequire.New(s.T()) - s.specBuilder = NewSpecBuilder() + parallelsuite.Run(t, new(specSuite)) } func (s *specSuite) checkSequenceRaw(spec *schedulepb.ScheduleSpec, start time.Time, seq ...time.Time) { s.T().Helper() - cs, err := s.specBuilder.NewCompiledSpec(spec) + cs, err := NewSpecBuilder().NewCompiledSpec(spec) s.NoError(err) for _, exp := range seq { next := cs.rawNextTime(start) @@ -40,20 +31,20 @@ func (s *specSuite) checkSequenceRaw(spec *schedulepb.ScheduleSpec, start time.T func (s *specSuite) checkSequenceFull(jitterSeed string, spec *schedulepb.ScheduleSpec, start time.Time, seq ...time.Time) { s.T().Helper() - cs, err := s.specBuilder.NewCompiledSpec(spec) + cs, err := NewSpecBuilder().NewCompiledSpec(spec) s.NoError(err) for _, exp := range seq { result := cs.GetNextTime(jitterSeed, start) if exp.IsZero() { - s.Require().True( + s.True( result.Nominal.IsZero(), "exp %v nominal should be zero, got %v", exp, result.Nominal, ) - s.Require().True(result.Next.IsZero(), "next should be zero") + s.True(result.Next.IsZero(), "next should be zero") break } - s.Require().False(result.Nominal.IsZero()) - s.Require().False(result.Next.IsZero()) + s.False(result.Nominal.IsZero()) + s.False(result.Next.IsZero()) s.Equal(exp, result.Next) start = result.Next } @@ -367,7 +358,7 @@ func (s *specSuite) TestSpecExclude() { } func (s *specSuite) TestExcludeAll() { - cs, err := s.specBuilder.NewCompiledSpec(&schedulepb.ScheduleSpec{ + cs, err := NewSpecBuilder().NewCompiledSpec(&schedulepb.ScheduleSpec{ Interval: []*schedulepb.IntervalSpec{ {Interval: durationpb.New(7 * 24 * time.Hour)}, }, diff --git a/service/worker/scheduler/workflow_test.go b/service/worker/scheduler/workflow_test.go index 34e40992d88..9e7d5c28f0d 100644 --- a/service/worker/scheduler/workflow_test.go +++ b/service/worker/scheduler/workflow_test.go @@ -4,12 +4,12 @@ import ( "context" "errors" "math/rand" + "sync" "testing" "time" "github.com/google/uuid" "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/suite" commonpb "go.temporal.io/api/common/v1" enumspb "go.temporal.io/api/enums/v1" failurepb "go.temporal.io/api/failure/v1" @@ -24,6 +24,7 @@ import ( "go.temporal.io/server/common/payload" "go.temporal.io/server/common/payloads" "go.temporal.io/server/common/searchattribute/sadefs" + "go.temporal.io/server/common/testing/parallelsuite" "go.temporal.io/server/common/testing/protoassert" "google.golang.org/protobuf/types/known/durationpb" "google.golang.org/protobuf/types/known/timestamppb" @@ -31,32 +32,49 @@ import ( type ( workflowSuite struct { - suite.Suite - testsuite.WorkflowTestSuite - env *testsuite.TestWorkflowEnvironment + parallelsuite.Suite[*workflowSuite] } ) var ( - baseStartTime = time.Date(2022, 6, 1, 0, 0, 0, 0, time.UTC) + baseStartTime = time.Date(2022, 6, 1, 0, 0, 0, 0, time.UTC) + workflowSuiteEnvs sync.Map ) func TestWorkflow(t *testing.T) { - suite.Run(t, new(workflowSuite)) + parallelsuite.RunLegacySequential(t, new(workflowSuite)) //nolint:staticcheck // SA1019: suite mutates scheduler package global tweakable policies. } -func (s *workflowSuite) SetupTest() { - s.env = s.NewTestWorkflowEnvironment() +func (s *workflowSuite) env() *testsuite.TestWorkflowEnvironment { + if env, ok := workflowSuiteEnvs.Load(s.T().Name()); ok { + return env.(*testsuite.TestWorkflowEnvironment) + } + env := s.newTestWorkflowEnvironment() + s.setEnv(env) + return env +} + +func (s *workflowSuite) setEnv(env *testsuite.TestWorkflowEnvironment) { + name := s.T().Name() + if _, ok := workflowSuiteEnvs.Load(name); !ok { + s.T().Cleanup(func() { + if env, ok := workflowSuiteEnvs.LoadAndDelete(name); ok { + env.(*testsuite.TestWorkflowEnvironment).AssertExpectations(s.T()) + } + }) + } + workflowSuiteEnvs.Store(name, env) } -func (s *workflowSuite) AfterTest(suiteName, testName string) { - s.env.AssertExpectations(s.T()) +func (s *workflowSuite) newTestWorkflowEnvironment() *testsuite.TestWorkflowEnvironment { + testSuite := &testsuite.WorkflowTestSuite{} + return testSuite.NewTestWorkflowEnvironment() } // test helpers func (s *workflowSuite) now() time.Time { - return s.env.Now().UTC() // env.Now() returns local time by default, force to UTC + return s.env().Now().UTC() // env.Now() returns local time by default, force to UTC } func (s *workflowSuite) defaultAction(id string) *schedulepb.ScheduleAction { @@ -87,14 +105,14 @@ func (s *workflowSuite) run(sched *schedulepb.Schedule, iterations int) { CurrentTweakablePolicies.IterationsBeforeContinueAsNew = iterations // fixed start time - s.env.SetStartTime(baseStartTime) + s.env().SetStartTime(baseStartTime) // fill this in so callers don't need to if sched.Action == nil { sched.Action = s.defaultAction("myid") } - s.env.ExecuteWorkflow(SchedulerWorkflow, &schedulespb.StartScheduleArgs{ + s.env().ExecuteWorkflow(SchedulerWorkflow, &schedulespb.StartScheduleArgs{ Schedule: sched, State: &schedulespb.InternalState{ Namespace: "myns", @@ -106,7 +124,7 @@ func (s *workflowSuite) run(sched *schedulepb.Schedule, iterations int) { } func (s *workflowSuite) describe() *schedulespb.DescribeResponse { - encoded, err := s.env.QueryWorkflow(QueryNameDescribe) + encoded, err := s.env().QueryWorkflow(QueryNameDescribe) s.NoError(err) var resp schedulespb.DescribeResponse s.NoError(encoded.Get(&resp)) @@ -125,13 +143,13 @@ func (s *workflowSuite) runningWorkflows() []string { // Low-level mock helpers: func (s *workflowSuite) expectStart(f func(req *schedulespb.StartWorkflowRequest) (*schedulespb.StartWorkflowResponse, error)) *testsuite.MockCallWrapper { - return s.env.OnActivity(new(activities).StartWorkflow, mock.Anything, mock.Anything).Once().Return( + return s.env().OnActivity(new(activities).StartWorkflow, mock.Anything, mock.Anything).Once().Return( func(_ context.Context, req *schedulespb.StartWorkflowRequest) (*schedulespb.StartWorkflowResponse, error) { resp, err := f(req) if resp == nil && err == nil { // fill in defaults so callers can be more concise resp = &schedulespb.StartWorkflowResponse{ RunId: uuid.NewString(), - RealStartTime: timestamppb.New(s.env.Now()), + RealStartTime: timestamppb.New(s.env().Now()), } } @@ -140,21 +158,21 @@ func (s *workflowSuite) expectStart(f func(req *schedulespb.StartWorkflowRequest } func (s *workflowSuite) expectWatch(f func(req *schedulespb.WatchWorkflowRequest) (*schedulespb.WatchWorkflowResponse, error)) *testsuite.MockCallWrapper { - return s.env.OnActivity(new(activities).WatchWorkflow, mock.Anything, mock.Anything).Once().Return( + return s.env().OnActivity(new(activities).WatchWorkflow, mock.Anything, mock.Anything).Once().Return( func(_ context.Context, req *schedulespb.WatchWorkflowRequest) (*schedulespb.WatchWorkflowResponse, error) { return f(req) }) } func (s *workflowSuite) expectCancel(f func(req *schedulespb.CancelWorkflowRequest) error) *testsuite.MockCallWrapper { - return s.env.OnActivity(new(activities).CancelWorkflow, mock.Anything, mock.Anything).Once().Return( + return s.env().OnActivity(new(activities).CancelWorkflow, mock.Anything, mock.Anything).Once().Return( func(_ context.Context, req *schedulespb.CancelWorkflowRequest) error { return f(req) }) } func (s *workflowSuite) expectTerminate(f func(req *schedulespb.TerminateWorkflowRequest) error) *testsuite.MockCallWrapper { - return s.env.OnActivity(new(activities).TerminateWorkflow, mock.Anything, mock.Anything).Once().Return( + return s.env().OnActivity(new(activities).TerminateWorkflow, mock.Anything, mock.Anything).Once().Return( func(_ context.Context, req *schedulespb.TerminateWorkflowRequest) error { return f(req) }) @@ -180,7 +198,7 @@ type runAcrossContinueState struct { func (s *workflowSuite) setupMocksForWorkflows(runs []workflowRun, state *runAcrossContinueState) { // capture this to avoid races between end of one test and start of next - env := s.env + env := s.env() for _, run := range runs { run := run // capture fresh value @@ -188,7 +206,7 @@ func (s *workflowSuite) setupMocksForWorkflows(runs []workflowRun, state *runAcr matchStart := mock.MatchedBy(func(req *schedulespb.StartWorkflowRequest) bool { return req.Request.WorkflowId == run.id }) - s.env.OnActivity(new(activities).StartWorkflow, mock.Anything, matchStart).Times(0).Maybe().Return( + s.env().OnActivity(new(activities).StartWorkflow, mock.Anything, matchStart).Times(0).Maybe().Return( func(_ context.Context, req *schedulespb.StartWorkflowRequest) (*schedulespb.StartWorkflowResponse, error) { if prev, ok := state.started[req.Request.WorkflowId]; ok { s.Failf("multiple starts", "for %s at %s (prev %s)", req.Request.WorkflowId, s.now(), prev) @@ -204,7 +222,7 @@ func (s *workflowSuite) setupMocksForWorkflows(runs []workflowRun, state *runAcr matchShortPoll := mock.MatchedBy(func(req *schedulespb.WatchWorkflowRequest) bool { return req.Execution.WorkflowId == run.id && !req.LongPoll }) - s.env.OnActivity(new(activities).WatchWorkflow, mock.Anything, matchShortPoll).Times(0).Maybe().Return( + s.env().OnActivity(new(activities).WatchWorkflow, mock.Anything, matchShortPoll).Times(0).Maybe().Return( func(_ context.Context, req *schedulespb.WatchWorkflowRequest) (*schedulespb.WatchWorkflowResponse, error) { if s.now().Before(run.end) { return &schedulespb.WatchWorkflowResponse{Status: enumspb.WORKFLOW_EXECUTION_STATUS_RUNNING}, nil @@ -215,7 +233,7 @@ func (s *workflowSuite) setupMocksForWorkflows(runs []workflowRun, state *runAcr matchLongPoll := mock.MatchedBy(func(req *schedulespb.WatchWorkflowRequest) bool { return req.Execution.WorkflowId == run.id && req.LongPoll }) - s.env.OnActivity(new(activities).WatchWorkflow, mock.Anything, matchLongPoll).Times(0).Maybe().AfterFn(func() time.Duration { + s.env().OnActivity(new(activities).WatchWorkflow, mock.Anything, matchLongPoll).Times(0).Maybe().AfterFn(func() time.Duration { // this can be called after end of workflow, use captured env return run.end.Sub(env.Now().UTC()) }).Return(func(_ context.Context, req *schedulespb.WatchWorkflowRequest) (*schedulespb.WatchWorkflowResponse, error) { @@ -223,7 +241,7 @@ func (s *workflowSuite) setupMocksForWorkflows(runs []workflowRun, state *runAcr }) } // catch unexpected starts - s.env.OnActivity(new(activities).StartWorkflow, mock.Anything, mock.Anything).Times(0).Maybe().Return( + s.env().OnActivity(new(activities).StartWorkflow, mock.Anything, mock.Anything).Times(0).Maybe().Return( func(_ context.Context, req *schedulespb.StartWorkflowRequest) (*schedulespb.StartWorkflowResponse, error) { s.Failf("unexpected start", "for %s at %s", req.Request.WorkflowId, s.now()) return nil, nil @@ -241,11 +259,11 @@ func (s *workflowSuite) setupDelayedCallbacks(start time.Time, cbs []delayedCall if delay := cb.at.Sub(start); delay > 0 { if cb.finishTest { cb.f = func() { - s.env.SetCurrentHistoryLength(impossibleHistorySize) // signals workflow loop to exit - state.finished = true // signals test to exit + s.env().SetCurrentHistoryLength(impossibleHistorySize) // signals workflow loop to exit + state.finished = true // signals test to exit } } - s.env.RegisterDelayedCallback(cb.f, delay) + s.env().RegisterDelayedCallback(cb.f, delay) } } } @@ -276,23 +294,23 @@ func (s *workflowSuite) runAcrossContinue( started: make(map[string]time.Time), } for { - s.env = s.NewTestWorkflowEnvironment() - s.env.SetStartTime(startTime) + s.setEnv(s.newTestWorkflowEnvironment()) + s.env().SetStartTime(startTime) s.setupMocksForWorkflows(runs, &state) s.setupDelayedCallbacks(startTime, cbs, &state) s.T().Logf("starting workflow with CAN every %d iterations, start time %s", CurrentTweakablePolicies.IterationsBeforeContinueAsNew, startTime) - s.env.ExecuteWorkflow(SchedulerWorkflow, startArgs) + s.env().ExecuteWorkflow(SchedulerWorkflow, startArgs) s.T().Logf("finished workflow, time is now %s, finished is %v", s.now(), state.finished) - s.True(s.env.IsWorkflowCompleted()) - result := s.env.GetWorkflowError() + s.True(s.env().IsWorkflowCompleted()) + result := s.env().GetWorkflowError() var canErr *workflow.ContinueAsNewError - s.Require().True(errors.As(result, &canErr), "result: %v", result) + s.True(errors.As(result, &canErr), "result: %v", result) - s.env.AssertExpectations(s.T()) + s.env().AssertExpectations(s.T()) if state.finished { break @@ -300,10 +318,10 @@ func (s *workflowSuite) runAcrossContinue( startTime = s.now() startArgs = nil - s.Require().NoError(payloads.Decode(canErr.Input, &startArgs)) + s.NoError(payloads.Decode(canErr.Input, &startArgs)) } // check starts that we actually got - s.Require().Equalf(len(runs), len(state.started), "started %#v", state.started) + s.Equalf(len(runs), len(state.started), "started %#v", state.started) for _, run := range runs { actual := state.started[run.id] inRange := !actual.Before(run.start.Add(-run.startTolerance)) && !actual.After(run.start.Add(run.startTolerance)) @@ -353,8 +371,8 @@ func (s *workflowSuite) TestStart() { Action: action, }, 2) // two iterations to start one workflow: first will sleep, second will start and then sleep again - s.True(s.env.IsWorkflowCompleted()) - s.True(workflow.IsContinueAsNewError(s.env.GetWorkflowError())) + s.True(s.env().IsWorkflowCompleted()) + s.True(workflow.IsContinueAsNewError(s.env().GetWorkflowError())) } func (s *workflowSuite) TestInitialPatch() { @@ -378,8 +396,8 @@ func (s *workflowSuite) TestInitialPatch() { }) CurrentTweakablePolicies.IterationsBeforeContinueAsNew = 2 - s.env.SetStartTime(baseStartTime) - s.env.ExecuteWorkflow(SchedulerWorkflow, &schedulespb.StartScheduleArgs{ + s.env().SetStartTime(baseStartTime) + s.env().ExecuteWorkflow(SchedulerWorkflow, &schedulespb.StartScheduleArgs{ Schedule: &schedulepb.Schedule{ Spec: &schedulepb.ScheduleSpec{ Interval: []*schedulepb.IntervalSpec{{ @@ -398,8 +416,8 @@ func (s *workflowSuite) TestInitialPatch() { TriggerImmediately: &schedulepb.TriggerImmediatelyRequest{}, }, }) - s.True(s.env.IsWorkflowCompleted()) - s.True(workflow.IsContinueAsNewError(s.env.GetWorkflowError())) + s.True(s.env().IsWorkflowCompleted()) + s.True(workflow.IsContinueAsNewError(s.env().GetWorkflowError())) } func (s *workflowSuite) TestCatchupWindow() { @@ -423,13 +441,13 @@ func (s *workflowSuite) TestCatchupWindow() { s.Equal("myid-2022-06-01T00:17:00Z", req.Request.WorkflowId) return nil, nil }) - s.env.RegisterDelayedCallback(func() { + s.env().RegisterDelayedCallback(func() { s.Equal(int64(5), s.describe().Info.MissedCatchupWindow) }, 18*time.Minute) CurrentTweakablePolicies.IterationsBeforeContinueAsNew = 2 - s.env.SetStartTime(baseStartTime) - s.env.ExecuteWorkflow(SchedulerWorkflow, &schedulespb.StartScheduleArgs{ + s.env().SetStartTime(baseStartTime) + s.env().ExecuteWorkflow(SchedulerWorkflow, &schedulespb.StartScheduleArgs{ Schedule: &schedulepb.Schedule{ Spec: &schedulepb.ScheduleSpec{ Calendar: []*schedulepb.CalendarSpec{{ @@ -451,18 +469,18 @@ func (s *workflowSuite) TestCatchupWindow() { LastProcessedTime: timestamppb.New(time.Date(2022, 5, 31, 18, 0, 0, 0, time.UTC)), }, }) - s.True(s.env.IsWorkflowCompleted()) - s.True(workflow.IsContinueAsNewError(s.env.GetWorkflowError())) + s.True(s.env().IsWorkflowCompleted()) + s.True(workflow.IsContinueAsNewError(s.env().GetWorkflowError())) } func (s *workflowSuite) TestCatchupWindowWhilePaused() { // written using low-level mocks so we can set initial state - s.env.RegisterDelayedCallback(func() { + s.env().RegisterDelayedCallback(func() { // should not count any "misses" since we were paused s.Equal(int64(0), s.describe().Info.MissedCatchupWindow) // unpause just to make the test end cleanly - s.env.SignalWorkflow(SignalNamePatch, &schedulepb.SchedulePatch{Unpause: "go ahead"}) + s.env().SignalWorkflow(SignalNamePatch, &schedulepb.SchedulePatch{Unpause: "go ahead"}) }, 3*time.Minute) s.expectStart(func(req *schedulespb.StartWorkflowRequest) (*schedulespb.StartWorkflowResponse, error) { s.True(time.Date(2022, 6, 1, 0, 17, 0, 0, time.UTC).Equal(s.now())) @@ -471,8 +489,8 @@ func (s *workflowSuite) TestCatchupWindowWhilePaused() { }) CurrentTweakablePolicies.IterationsBeforeContinueAsNew = 3 - s.env.SetStartTime(baseStartTime) - s.env.ExecuteWorkflow(SchedulerWorkflow, &schedulespb.StartScheduleArgs{ + s.env().SetStartTime(baseStartTime) + s.env().ExecuteWorkflow(SchedulerWorkflow, &schedulespb.StartScheduleArgs{ Schedule: &schedulepb.Schedule{ Spec: &schedulepb.ScheduleSpec{ Calendar: []*schedulepb.CalendarSpec{{ @@ -497,8 +515,8 @@ func (s *workflowSuite) TestCatchupWindowWhilePaused() { LastProcessedTime: timestamppb.New(time.Date(2022, 5, 31, 18, 0, 0, 0, time.UTC)), }, }) - s.True(s.env.IsWorkflowCompleted()) - s.True(workflow.IsContinueAsNewError(s.env.GetWorkflowError()), s.env.GetWorkflowError()) + s.True(s.env().IsWorkflowCompleted()) + s.True(workflow.IsContinueAsNewError(s.env().GetWorkflowError()), s.env().GetWorkflowError()) } func (s *workflowSuite) TestOverlapSkip() { @@ -816,8 +834,8 @@ func (s *workflowSuite) TestOverlapCancel() { OverlapPolicy: enumspb.SCHEDULE_OVERLAP_POLICY_CANCEL_OTHER, }, }, 4) - s.True(s.env.IsWorkflowCompleted()) - s.True(workflow.IsContinueAsNewError(s.env.GetWorkflowError())) + s.True(s.env().IsWorkflowCompleted()) + s.True(workflow.IsContinueAsNewError(s.env().GetWorkflowError())) } func (s *workflowSuite) TestOverlapTerminate() { @@ -867,8 +885,8 @@ func (s *workflowSuite) TestOverlapTerminate() { OverlapPolicy: enumspb.SCHEDULE_OVERLAP_POLICY_TERMINATE_OTHER, }, }, 4) - s.True(s.env.IsWorkflowCompleted()) - s.True(workflow.IsContinueAsNewError(s.env.GetWorkflowError())) + s.True(s.env().IsWorkflowCompleted()) + s.True(workflow.IsContinueAsNewError(s.env().GetWorkflowError())) } func (s *workflowSuite) TestOverlapAllowAll() { @@ -955,8 +973,8 @@ func (s *workflowSuite) TestFailedStart() { }}, }, }, 4) - s.True(s.env.IsWorkflowCompleted()) - s.True(workflow.IsContinueAsNewError(s.env.GetWorkflowError())) + s.True(s.env().IsWorkflowCompleted()) + s.True(workflow.IsContinueAsNewError(s.env().GetWorkflowError())) } func (s *workflowSuite) TestLastCompletionResultAndContinuedFailure() { @@ -1028,8 +1046,8 @@ func (s *workflowSuite) TestLastCompletionResultAndContinuedFailure() { OverlapPolicy: enumspb.SCHEDULE_OVERLAP_POLICY_SKIP, }, }, 5) - s.True(s.env.IsWorkflowCompleted()) - s.True(workflow.IsContinueAsNewError(s.env.GetWorkflowError())) + s.True(s.env().IsWorkflowCompleted()) + s.True(workflow.IsContinueAsNewError(s.env().GetWorkflowError())) } func (s *workflowSuite) TestOnlyStartForAllowAll() { @@ -1064,8 +1082,8 @@ func (s *workflowSuite) TestOnlyStartForAllowAll() { OverlapPolicy: enumspb.SCHEDULE_OVERLAP_POLICY_ALLOW_ALL, }, }, 4) - s.True(s.env.IsWorkflowCompleted()) - s.True(workflow.IsContinueAsNewError(s.env.GetWorkflowError())) + s.True(s.env().IsWorkflowCompleted()) + s.True(workflow.IsContinueAsNewError(s.env().GetWorkflowError())) } func (s *workflowSuite) TestPauseOnFailure() { @@ -1089,10 +1107,10 @@ func (s *workflowSuite) TestPauseOnFailure() { }, }, nil }) - s.env.RegisterDelayedCallback(func() { + s.env().RegisterDelayedCallback(func() { s.False(s.describe().Schedule.State.Paused) }, 9*time.Minute) - s.env.RegisterDelayedCallback(func() { + s.env().RegisterDelayedCallback(func() { desc := s.describe() s.True(desc.Schedule.State.Paused) s.Contains(desc.Schedule.State.Notes, "paused due to workflow failure") @@ -1109,14 +1127,14 @@ func (s *workflowSuite) TestPauseOnFailure() { PauseOnFailure: true, }, }, 3) - s.True(s.env.IsWorkflowCompleted()) + s.True(s.env().IsWorkflowCompleted()) // doesn't end properly since it sleeps forever after pausing } func (s *workflowSuite) TestCompileError() { // written using low-level mocks since it sleeps forever - s.env.RegisterDelayedCallback(func() { + s.env().RegisterDelayedCallback(func() { s.Contains(s.describe().Info.InvalidScheduleError, "Month is not in range [1-12]") }, 1*time.Minute) @@ -1151,7 +1169,7 @@ func (s *workflowSuite) TestTriggerImmediate() { at: time.Date(2022, 6, 1, 0, 20, 0, 0, time.UTC), f: func() { // this gets skipped because a scheduled run is still running - s.env.SignalWorkflow(SignalNamePatch, &schedulepb.SchedulePatch{ + s.env().SignalWorkflow(SignalNamePatch, &schedulepb.SchedulePatch{ TriggerImmediately: &schedulepb.TriggerImmediatelyRequest{ ScheduledTime: timestamppb.New(time.Date(2022, 6, 1, 0, 20, 0, 0, time.UTC)), }, @@ -1171,7 +1189,7 @@ func (s *workflowSuite) TestTriggerImmediate() { at: time.Date(2022, 6, 1, 0, 30, 0, 0, time.UTC), f: func() { // this one runs with overridden overlap policy - s.env.SignalWorkflow(SignalNamePatch, &schedulepb.SchedulePatch{ + s.env().SignalWorkflow(SignalNamePatch, &schedulepb.SchedulePatch{ TriggerImmediately: &schedulepb.TriggerImmediatelyRequest{ ScheduledTime: timestamppb.New(time.Date(2022, 6, 1, 0, 30, 0, 0, time.UTC)), OverlapPolicy: enumspb.SCHEDULE_OVERLAP_POLICY_ALLOW_ALL, @@ -1236,7 +1254,7 @@ func (s *workflowSuite) TestBackfill() { { at: time.Date(2022, 6, 1, 0, 5, 0, 0, time.UTC), f: func() { - s.env.SignalWorkflow(SignalNamePatch, &schedulepb.SchedulePatch{ + s.env().SignalWorkflow(SignalNamePatch, &schedulepb.SchedulePatch{ BackfillRequest: []*schedulepb.BackfillRequest{{ StartTime: timestamppb.New(time.Date(2022, 5, 31, 0, 0, 0, 0, time.UTC)), EndTime: timestamppb.New(time.Date(2022, 6, 1, 0, 0, 0, 0, time.UTC)), @@ -1301,7 +1319,7 @@ func (s *workflowSuite) TestBackfillInclusiveStartEnd() { OverlapPolicy: enumspb.SCHEDULE_OVERLAP_POLICY_BUFFER_ALL, } - s.env.SignalWorkflow(SignalNamePatch, &schedulepb.SchedulePatch{ + s.env().SignalWorkflow(SignalNamePatch, &schedulepb.SchedulePatch{ BackfillRequest: []*schedulepb.BackfillRequest{triggerBackfill, ignoreBackfill}, }) }, @@ -1362,7 +1380,7 @@ func (s *workflowSuite) TestHugeBackfillAllowAll() { // as a workflow timer, so use an odd interval to force it to be different. at: baseStartTime.Add(time.Minute).Add(time.Duration(i) * 1113 * time.Millisecond), f: func() { - s.env.SignalWorkflow(SignalNamePatch, &schedulepb.SchedulePatch{ + s.env().SignalWorkflow(SignalNamePatch, &schedulepb.SchedulePatch{ BackfillRequest: []*schedulepb.BackfillRequest{{ StartTime: timestamppb.New(base.Add(time.Duration(i*backfillRuns/backfills) * time.Hour)), EndTime: timestamppb.New(base.Add(time.Duration((i+1)*backfillRuns/backfills-1) * time.Hour)), @@ -1428,7 +1446,7 @@ func (s *workflowSuite) TestHugeBackfillBuffer() { delayedCallbacks[i] = delayedCallback{ at: baseStartTime.Add(time.Minute).Add(time.Duration(i) * 1113 * time.Millisecond), f: func() { - s.env.SignalWorkflow(SignalNamePatch, &schedulepb.SchedulePatch{ + s.env().SignalWorkflow(SignalNamePatch, &schedulepb.SchedulePatch{ BackfillRequest: []*schedulepb.BackfillRequest{{ StartTime: timestamppb.New(base.Add(time.Duration(i*backfillRuns/backfills) * time.Hour)), EndTime: timestamppb.New(base.Add(time.Duration((i+1)*backfillRuns/backfills-1) * time.Hour)), @@ -1485,7 +1503,7 @@ func (s *workflowSuite) TestPause() { { at: time.Date(2022, 6, 1, 0, 7, 7, 0, time.UTC), f: func() { - s.env.SignalWorkflow(SignalNamePatch, &schedulepb.SchedulePatch{ + s.env().SignalWorkflow(SignalNamePatch, &schedulepb.SchedulePatch{ Pause: "paused", }) }, @@ -1501,7 +1519,7 @@ func (s *workflowSuite) TestPause() { { at: time.Date(2022, 6, 1, 0, 26, 7, 0, time.UTC), f: func() { - s.env.SignalWorkflow(SignalNamePatch, &schedulepb.SchedulePatch{ + s.env().SignalWorkflow(SignalNamePatch, &schedulepb.SchedulePatch{ Unpause: "go ahead", }) }, @@ -1567,15 +1585,15 @@ func (s *workflowSuite) TestUpdate() { at: time.Date(2022, 6, 1, 0, 9, 5, 0, time.UTC), f: func() { // shouldn't crash - s.env.SignalWorkflow(SignalNameUpdate, nil) - s.env.SignalWorkflow(SignalNameUpdate, &schedulespb.FullUpdateRequest{}) + s.env().SignalWorkflow(SignalNameUpdate, nil) + s.env().SignalWorkflow(SignalNameUpdate, &schedulespb.FullUpdateRequest{}) }, }, { at: time.Date(2022, 6, 1, 0, 9, 7, 0, time.UTC), f: func() { desc := s.describe() - s.env.SignalWorkflow(SignalNameUpdate, &schedulespb.FullUpdateRequest{ + s.env().SignalWorkflow(SignalNameUpdate, &schedulespb.FullUpdateRequest{ ConflictToken: desc.ConflictToken, Schedule: &schedulepb.Schedule{ Spec: &schedulepb.ScheduleSpec{ @@ -1600,7 +1618,7 @@ func (s *workflowSuite) TestUpdate() { at: time.Date(2022, 6, 1, 0, 12, 7, 0, time.UTC), f: func() { desc := s.describe() - s.env.SignalWorkflow(SignalNameUpdate, &schedulespb.FullUpdateRequest{ + s.env().SignalWorkflow(SignalNameUpdate, &schedulespb.FullUpdateRequest{ ConflictToken: desc.ConflictToken + 37, // conflict, should not take effect Schedule: &schedulepb.Schedule{}, }) @@ -1650,7 +1668,7 @@ func (s *workflowSuite) TestUpdateNotRetroactive() { { at: time.Date(2022, 6, 1, 1, 7, 10, 0, time.UTC), f: func() { - s.env.SignalWorkflow(SignalNameUpdate, &schedulespb.FullUpdateRequest{ + s.env().SignalWorkflow(SignalNameUpdate, &schedulespb.FullUpdateRequest{ Schedule: &schedulepb.Schedule{ Spec: &schedulepb.ScheduleSpec{ Interval: []*schedulepb.IntervalSpec{{ @@ -1728,7 +1746,7 @@ func (s *workflowSuite) TestUpdateBetweenNominalAndJitter() { // update after nominal time 03:00:00 but before jittered time 03:37:29 at: time.Date(2022, 6, 1, 3, 22, 10, 0, time.UTC), f: func() { - s.env.SignalWorkflow(SignalNameUpdate, &schedulespb.FullUpdateRequest{ + s.env().SignalWorkflow(SignalNameUpdate, &schedulespb.FullUpdateRequest{ Schedule: &schedulepb.Schedule{ Spec: spec, Action: s.defaultAction("newid"), @@ -1781,7 +1799,7 @@ func (s *workflowSuite) TestSignalBetweenNominalAndJittered() { // signal between nominal and jittered time at: time.Date(2022, 6, 1, 3, 22, 10, 0, time.UTC), f: func() { - s.env.SignalWorkflow(SignalNameRefresh, nil) + s.env().SignalWorkflow(SignalNameRefresh, nil) }, }, { @@ -1834,13 +1852,13 @@ func (s *workflowSuite) TestPauseUnpauseBetweenNominalAndJittered() { { at: time.Date(2022, 6, 1, 3, 20, 0, 0, time.UTC), f: func() { - s.env.SignalWorkflow(SignalNamePatch, &schedulepb.SchedulePatch{Pause: "paused"}) + s.env().SignalWorkflow(SignalNamePatch, &schedulepb.SchedulePatch{Pause: "paused"}) }, }, { at: time.Date(2022, 6, 1, 3, 30, 0, 0, time.UTC), f: func() { - s.env.SignalWorkflow(SignalNamePatch, &schedulepb.SchedulePatch{Unpause: "go ahead"}) + s.env().SignalWorkflow(SignalNamePatch, &schedulepb.SchedulePatch{Unpause: "go ahead"}) }, }, { @@ -1885,28 +1903,28 @@ func (s *workflowSuite) TestLimitedActions() { return &schedulespb.WatchWorkflowResponse{Status: enumspb.WORKFLOW_EXECUTION_STATUS_COMPLETED}, nil }) - s.env.RegisterDelayedCallback(func() { + s.env().RegisterDelayedCallback(func() { desc := s.describe() s.Equal(int64(2), desc.Schedule.State.RemainingActions) s.Equal(2, len(desc.Info.FutureActionTimes)) }, 1*time.Minute) - s.env.RegisterDelayedCallback(func() { + s.env().RegisterDelayedCallback(func() { desc := s.describe() s.Equal(int64(1), desc.Schedule.State.RemainingActions) s.Equal(1, len(desc.Info.FutureActionTimes)) }, 5*time.Minute) - s.env.RegisterDelayedCallback(func() { + s.env().RegisterDelayedCallback(func() { desc := s.describe() s.Equal(int64(0), desc.Schedule.State.RemainingActions) s.Equal(0, len(desc.Info.FutureActionTimes)) s.Equal(1, len(s.runningWorkflows())) }, 7*time.Minute) - s.env.RegisterDelayedCallback(func() { + s.env().RegisterDelayedCallback(func() { // hasn't updated yet since we slept past :09 s.Equal(1, len(s.runningWorkflows())) - s.env.SignalWorkflow(SignalNameRefresh, nil) + s.env().SignalWorkflow(SignalNameRefresh, nil) }, 10*time.Minute) - s.env.RegisterDelayedCallback(func() { + s.env().RegisterDelayedCallback(func() { s.Equal(0, len(s.runningWorkflows())) }, 10*time.Minute+1*time.Second) @@ -1924,7 +1942,7 @@ func (s *workflowSuite) TestLimitedActions() { OverlapPolicy: enumspb.SCHEDULE_OVERLAP_POLICY_SKIP, }, }, 4) - s.True(s.env.IsWorkflowCompleted()) + s.True(s.env().IsWorkflowCompleted()) // doesn't end properly since it sleeps forever after pausing } @@ -1972,7 +1990,7 @@ func (s *workflowSuite) TestLotsOfIterations() { delayedCallbacks[i] = delayedCallback{ at: callbackTime, f: func() { - s.env.SignalWorkflow(SignalNamePatch, &schedulepb.SchedulePatch{ + s.env().SignalWorkflow(SignalNamePatch, &schedulepb.SchedulePatch{ BackfillRequest: []*schedulepb.BackfillRequest{{ StartTime: timestamppb.New(callBackRangeStartTime), EndTime: timestamppb.New(callBackRangeStartTime.Add(time.Duration(maxRuns) * time.Hour)), @@ -2022,8 +2040,8 @@ func (s *workflowSuite) TestExitScheduleWorkflowWhenNoActions() { }) CurrentTweakablePolicies.IterationsBeforeContinueAsNew = 5 - s.env.SetStartTime(baseStartTime) - s.env.ExecuteWorkflow(SchedulerWorkflow, &schedulespb.StartScheduleArgs{ + s.env().SetStartTime(baseStartTime) + s.env().ExecuteWorkflow(SchedulerWorkflow, &schedulespb.StartScheduleArgs{ Schedule: &schedulepb.Schedule{ Spec: &schedulepb.ScheduleSpec{ Interval: []*schedulepb.IntervalSpec{{ @@ -2043,9 +2061,9 @@ func (s *workflowSuite) TestExitScheduleWorkflowWhenNoActions() { ConflictToken: InitialConflictToken, }, }) - s.True(s.env.IsWorkflowCompleted()) - s.False(workflow.IsContinueAsNewError(s.env.GetWorkflowError())) - s.True(s.env.Now().Sub(time.Date(2022, 6, 1, 0, 30, 0, 0, time.UTC)) == CurrentTweakablePolicies.RetentionTime) + s.True(s.env().IsWorkflowCompleted()) + s.False(workflow.IsContinueAsNewError(s.env().GetWorkflowError())) + s.True(s.env().Now().Sub(time.Date(2022, 6, 1, 0, 30, 0, 0, time.UTC)) == CurrentTweakablePolicies.RetentionTime) } func (s *workflowSuite) TestExitScheduleWorkflowWhenNoNextTime() { @@ -2057,8 +2075,8 @@ func (s *workflowSuite) TestExitScheduleWorkflowWhenNoNextTime() { }) CurrentTweakablePolicies.IterationsBeforeContinueAsNew = 3 - s.env.SetStartTime(baseStartTime) - s.env.ExecuteWorkflow(SchedulerWorkflow, &schedulespb.StartScheduleArgs{ + s.env().SetStartTime(baseStartTime) + s.env().ExecuteWorkflow(SchedulerWorkflow, &schedulespb.StartScheduleArgs{ Schedule: &schedulepb.Schedule{ Spec: &schedulepb.ScheduleSpec{ Calendar: []*schedulepb.CalendarSpec{{ @@ -2079,17 +2097,17 @@ func (s *workflowSuite) TestExitScheduleWorkflowWhenNoNextTime() { ConflictToken: InitialConflictToken, }, }) - s.True(s.env.IsWorkflowCompleted()) - s.False(workflow.IsContinueAsNewError(s.env.GetWorkflowError())) - s.True(s.env.Now().Sub(time.Date(2022, 6, 1, 1, 0, 0, 0, time.UTC)) == CurrentTweakablePolicies.RetentionTime) + s.True(s.env().IsWorkflowCompleted()) + s.False(workflow.IsContinueAsNewError(s.env().GetWorkflowError())) + s.True(s.env().Now().Sub(time.Date(2022, 6, 1, 1, 0, 0, 0, time.UTC)) == CurrentTweakablePolicies.RetentionTime) } func (s *workflowSuite) TestExitScheduleWorkflowWhenEmpty() { scheduleId := "myschedule" CurrentTweakablePolicies.IterationsBeforeContinueAsNew = 3 - s.env.SetStartTime(baseStartTime) - s.env.ExecuteWorkflow(SchedulerWorkflow, &schedulespb.StartScheduleArgs{ + s.env().SetStartTime(baseStartTime) + s.env().ExecuteWorkflow(SchedulerWorkflow, &schedulespb.StartScheduleArgs{ Schedule: &schedulepb.Schedule{ Action: s.defaultAction("myid"), }, @@ -2101,9 +2119,9 @@ func (s *workflowSuite) TestExitScheduleWorkflowWhenEmpty() { }, }) - s.True(s.env.IsWorkflowCompleted()) - s.False(workflow.IsContinueAsNewError(s.env.GetWorkflowError())) - s.True(s.env.Now().Sub(baseStartTime) == CurrentTweakablePolicies.RetentionTime) + s.True(s.env().IsWorkflowCompleted()) + s.False(workflow.IsContinueAsNewError(s.env().GetWorkflowError())) + s.True(s.env().Now().Sub(baseStartTime) == CurrentTweakablePolicies.RetentionTime) } func (s *workflowSuite) TestCANByIterations() { @@ -2125,8 +2143,8 @@ func (s *workflowSuite) TestCANByIterations() { }).Times(0).Maybe() // this is ignored because we set iters explicitly - s.env.RegisterDelayedCallback(func() { - s.env.SetContinueAsNewSuggested(true) + s.env().RegisterDelayedCallback(func() { + s.env().SetContinueAsNewSuggested(true) }, 5*time.Minute*iters/2-time.Second) s.run(&schedulepb.Schedule{ @@ -2139,8 +2157,8 @@ func (s *workflowSuite) TestCANByIterations() { OverlapPolicy: enumspb.SCHEDULE_OVERLAP_POLICY_ALLOW_ALL, }, }, iters) - s.True(s.env.IsWorkflowCompleted()) - s.True(workflow.IsContinueAsNewError(s.env.GetWorkflowError())) + s.True(s.env().IsWorkflowCompleted()) + s.True(workflow.IsContinueAsNewError(s.env().GetWorkflowError())) } func (s *workflowSuite) TestCANBySuggested() { @@ -2161,8 +2179,8 @@ func (s *workflowSuite) TestCANBySuggested() { return nil, nil }).Times(0).Maybe() - s.env.RegisterDelayedCallback(func() { - s.env.SetContinueAsNewSuggested(true) + s.env().RegisterDelayedCallback(func() { + s.env().SetContinueAsNewSuggested(true) }, 5*time.Minute*iters-time.Second) s.run(&schedulepb.Schedule{ @@ -2175,8 +2193,8 @@ func (s *workflowSuite) TestCANBySuggested() { OverlapPolicy: enumspb.SCHEDULE_OVERLAP_POLICY_ALLOW_ALL, }, }, 0) // 0 means use suggested - s.True(s.env.IsWorkflowCompleted()) - s.True(workflow.IsContinueAsNewError(s.env.GetWorkflowError())) + s.True(s.env().IsWorkflowCompleted()) + s.True(workflow.IsContinueAsNewError(s.env().GetWorkflowError())) } func (s *workflowSuite) TestCANBySuggestedWithSignals() { @@ -2206,13 +2224,13 @@ func (s *workflowSuite) TestCANBySuggestedWithSignals() { return nil, nil }).Times(0).Maybe() - s.env.RegisterDelayedCallback(func() { - s.env.SetContinueAsNewSuggested(true) + s.env().RegisterDelayedCallback(func() { + s.env().SetContinueAsNewSuggested(true) }, suggestCANAt) for _, d := range runs { - s.env.RegisterDelayedCallback(func() { - s.env.SignalWorkflow(SignalNamePatch, &schedulepb.SchedulePatch{ + s.env().RegisterDelayedCallback(func() { + s.env().SignalWorkflow(SignalNamePatch, &schedulepb.SchedulePatch{ TriggerImmediately: &schedulepb.TriggerImmediatelyRequest{}, }) }, d) @@ -2231,8 +2249,8 @@ func (s *workflowSuite) TestCANBySuggestedWithSignals() { Paused: true, }, }, 0) // 0 means use suggested - s.True(s.env.IsWorkflowCompleted()) - s.True(workflow.IsContinueAsNewError(s.env.GetWorkflowError())) + s.True(s.env().IsWorkflowCompleted()) + s.True(workflow.IsContinueAsNewError(s.env().GetWorkflowError())) } func (s *workflowSuite) TestCANBySignal() { @@ -2253,8 +2271,8 @@ func (s *workflowSuite) TestCANBySignal() { return nil, nil }).Times(0).Maybe() - s.env.RegisterDelayedCallback(func() { - s.env.SignalWorkflow(SignalNameForceCAN, nil) + s.env().RegisterDelayedCallback(func() { + s.env().SignalWorkflow(SignalNameForceCAN, nil) }, 5*time.Minute*iters-time.Second) s.run(&schedulepb.Schedule{ @@ -2267,24 +2285,24 @@ func (s *workflowSuite) TestCANBySignal() { OverlapPolicy: enumspb.SCHEDULE_OVERLAP_POLICY_ALLOW_ALL, }, }, 0) // 0 means use suggested - s.True(s.env.IsWorkflowCompleted()) - s.True(workflow.IsContinueAsNewError(s.env.GetWorkflowError())) + s.True(s.env().IsWorkflowCompleted()) + s.True(workflow.IsContinueAsNewError(s.env().GetWorkflowError())) } func (s *workflowSuite) TestMigrateSuccess() { // Mock MigrateSchedule activity to succeed. - s.env.OnActivity(new(activities).MigrateScheduleToChasm, mock.Anything, mock.Anything).Once().Return(nil) + s.env().OnActivity(new(activities).MigrateScheduleToChasm, mock.Anything, mock.Anything).Once().Return(nil) // Enable migration and request it via signal after the first iteration. enableMigration := false - s.env.RegisterDelayedCallback(func() { + s.env().RegisterDelayedCallback(func() { enableMigration = true - s.env.SignalWorkflow(SignalNameMigrateToChasm, nil) + s.env().SignalWorkflow(SignalNameMigrateToChasm, nil) }, 1*time.Second) CurrentTweakablePolicies.IterationsBeforeContinueAsNew = 100 - s.env.SetStartTime(baseStartTime) - s.env.ExecuteWorkflow(func(ctx workflow.Context, args *schedulespb.StartScheduleArgs) error { + s.env().SetStartTime(baseStartTime) + s.env().ExecuteWorkflow(func(ctx workflow.Context, args *schedulespb.StartScheduleArgs) error { return schedulerWorkflowWithSpecBuilder(ctx, args, NewSpecBuilder(), func() bool { return enableMigration }, func() bool { return true }) }, &schedulespb.StartScheduleArgs{ @@ -2305,15 +2323,15 @@ func (s *workflowSuite) TestMigrateSuccess() { }) // Workflow should complete successfully (not CAN) after migration. - s.True(s.env.IsWorkflowCompleted()) - s.NoError(s.env.GetWorkflowError()) + s.True(s.env().IsWorkflowCompleted()) + s.NoError(s.env().GetWorkflowError()) } func (s *workflowSuite) TestMigrateFailure() { // Mock MigrateSchedule activity to always fail. Migration is retried // each iteration since PendingMigration is persisted in State. migrateCalls := 0 - s.env.OnActivity(new(activities).MigrateScheduleToChasm, mock.Anything, mock.Anything).Return( + s.env().OnActivity(new(activities).MigrateScheduleToChasm, mock.Anything, mock.Anything).Return( func(context.Context, *schedulerpb.CreateFromMigrationStateRequest) error { migrateCalls++ return errors.New("migration failed") @@ -2321,21 +2339,21 @@ func (s *workflowSuite) TestMigrateFailure() { // Enable migration and request it via signal after the first iteration. enableMigration := false - s.env.RegisterDelayedCallback(func() { + s.env().RegisterDelayedCallback(func() { enableMigration = true - s.env.SignalWorkflow(SignalNameMigrateToChasm, nil) + s.env().SignalWorkflow(SignalNameMigrateToChasm, nil) }, 1*time.Second) // After ~5 iterations (5 hours of simulated time), the workflow should // still be running -- migration failed but the scheduler continues. stillRunning := false - s.env.RegisterDelayedCallback(func() { - stillRunning = !s.env.IsWorkflowCompleted() + s.env().RegisterDelayedCallback(func() { + stillRunning = !s.env().IsWorkflowCompleted() }, 5*time.Hour) CurrentTweakablePolicies.IterationsBeforeContinueAsNew = 100 - s.env.SetStartTime(baseStartTime) - s.env.ExecuteWorkflow(func(ctx workflow.Context, args *schedulespb.StartScheduleArgs) error { + s.env().SetStartTime(baseStartTime) + s.env().ExecuteWorkflow(func(ctx workflow.Context, args *schedulespb.StartScheduleArgs) error { return schedulerWorkflowWithSpecBuilder(ctx, args, NewSpecBuilder(), func() bool { return enableMigration }, func() bool { return true }) }, &schedulespb.StartScheduleArgs{ @@ -2362,16 +2380,16 @@ func (s *workflowSuite) TestMigrateFailure() { // Verify PendingMigration is persisted in CAN state. var canErr *workflow.ContinueAsNewError - s.Require().ErrorAs(s.env.GetWorkflowError(), &canErr) + s.ErrorAs(s.env().GetWorkflowError(), &canErr) var canArgs schedulespb.StartScheduleArgs - s.Require().NoError(payloads.Decode(canErr.Input, &canArgs)) + s.NoError(payloads.Decode(canErr.Input, &canArgs)) s.True(canArgs.State.PendingMigration, "PendingMigration should be set in CAN state") } func (s *workflowSuite) TestMigrateFailureThenRetrySuccess() { // First attempt fails, second attempt succeeds (on next run loop iteration). migrateCalls := 0 - s.env.OnActivity(new(activities).MigrateScheduleToChasm, mock.Anything, mock.Anything).Return( + s.env().OnActivity(new(activities).MigrateScheduleToChasm, mock.Anything, mock.Anything).Return( func(context.Context, *schedulerpb.CreateFromMigrationStateRequest) error { migrateCalls++ if migrateCalls == 1 { @@ -2382,14 +2400,14 @@ func (s *workflowSuite) TestMigrateFailureThenRetrySuccess() { // Enable migration and request it via signal after the first iteration. enableMigration := false - s.env.RegisterDelayedCallback(func() { + s.env().RegisterDelayedCallback(func() { enableMigration = true - s.env.SignalWorkflow(SignalNameMigrateToChasm, nil) + s.env().SignalWorkflow(SignalNameMigrateToChasm, nil) }, 1*time.Second) CurrentTweakablePolicies.IterationsBeforeContinueAsNew = 100 - s.env.SetStartTime(baseStartTime) - s.env.ExecuteWorkflow(func(ctx workflow.Context, args *schedulespb.StartScheduleArgs) error { + s.env().SetStartTime(baseStartTime) + s.env().ExecuteWorkflow(func(ctx workflow.Context, args *schedulespb.StartScheduleArgs) error { return schedulerWorkflowWithSpecBuilder(ctx, args, NewSpecBuilder(), func() bool { return enableMigration }, func() bool { return true }) }, &schedulespb.StartScheduleArgs{ @@ -2411,15 +2429,15 @@ func (s *workflowSuite) TestMigrateFailureThenRetrySuccess() { // Migration should succeed on second attempt without a new signal, // proving PendingMigration persists across run loop iterations. - s.True(s.env.IsWorkflowCompleted()) - s.Require().NoError(s.env.GetWorkflowError()) + s.True(s.env().IsWorkflowCompleted()) + s.NoError(s.env().GetWorkflowError()) s.Equal(2, migrateCalls, "migration should fail once then succeed on retry") } func (s *workflowSuite) TestMigrateFailureThenSignal() { // Mock MigrateSchedule activity to always fail. migrateCalls := 0 - s.env.OnActivity(new(activities).MigrateScheduleToChasm, mock.Anything, mock.Anything).Return( + s.env().OnActivity(new(activities).MigrateScheduleToChasm, mock.Anything, mock.Anything).Return( func(context.Context, *schedulerpb.CreateFromMigrationStateRequest) error { migrateCalls++ return errors.New("migration failed") @@ -2427,31 +2445,31 @@ func (s *workflowSuite) TestMigrateFailureThenSignal() { // Enable migration and request it via signal after the first iteration. enableMigration := false - s.env.RegisterDelayedCallback(func() { + s.env().RegisterDelayedCallback(func() { enableMigration = true - s.env.SignalWorkflow(SignalNameMigrateToChasm, nil) + s.env().SignalWorkflow(SignalNameMigrateToChasm, nil) }, 1*time.Second) // After migration failure, send a pause patch and verify it's processed, // proving the workflow kept running and still handles signals. - s.env.RegisterDelayedCallback(func() { - s.env.SignalWorkflow(SignalNamePatch, &schedulepb.SchedulePatch{ + s.env().RegisterDelayedCallback(func() { + s.env().SignalWorkflow(SignalNamePatch, &schedulepb.SchedulePatch{ Pause: "paused after failed migration", }) }, 5*time.Second) stillRunning := false - s.env.RegisterDelayedCallback(func() { + s.env().RegisterDelayedCallback(func() { desc := s.describe() s.True(desc.Schedule.State.Paused) s.Equal("paused after failed migration", desc.Schedule.State.Notes) - stillRunning = !s.env.IsWorkflowCompleted() + stillRunning = !s.env().IsWorkflowCompleted() // Send force-CAN to unblock the workflow (paused with no timer). - s.env.SignalWorkflow(SignalNameForceCAN, nil) + s.env().SignalWorkflow(SignalNameForceCAN, nil) }, 10*time.Second) CurrentTweakablePolicies.IterationsBeforeContinueAsNew = 100 - s.env.SetStartTime(baseStartTime) - s.env.ExecuteWorkflow(func(ctx workflow.Context, args *schedulespb.StartScheduleArgs) error { + s.env().SetStartTime(baseStartTime) + s.env().ExecuteWorkflow(func(ctx workflow.Context, args *schedulespb.StartScheduleArgs) error { return schedulerWorkflowWithSpecBuilder(ctx, args, NewSpecBuilder(), func() bool { return enableMigration }, func() bool { return true }) }, &schedulespb.StartScheduleArgs{ @@ -2479,23 +2497,23 @@ func (s *workflowSuite) TestMigrateFailureThenSignal() { // Verify PendingMigration is persisted in CAN state. var canErr *workflow.ContinueAsNewError - s.Require().ErrorAs(s.env.GetWorkflowError(), &canErr) + s.ErrorAs(s.env().GetWorkflowError(), &canErr) var canArgs schedulespb.StartScheduleArgs - s.Require().NoError(payloads.Decode(canErr.Input, &canArgs)) + s.NoError(payloads.Decode(canErr.Input, &canArgs)) s.True(canArgs.State.PendingMigration, "PendingMigration should be set in CAN state") } func (s *workflowSuite) TestMigrateDynamicConfig() { // Enable migration by threading enableCHASMMigration=true through the closure (race-safe). // Mock MigrateSchedule activity to succeed. - s.env.OnActivity(new(activities).MigrateScheduleToChasm, mock.Anything, mock.Anything).Once().Return(nil) + s.env().OnActivity(new(activities).MigrateScheduleToChasm, mock.Anything, mock.Anything).Once().Return(nil) prevTweakables := CurrentTweakablePolicies CurrentTweakablePolicies.IterationsBeforeContinueAsNew = 100 defer func() { CurrentTweakablePolicies = prevTweakables }() - s.env.SetStartTime(baseStartTime) - s.env.ExecuteWorkflow(func(ctx workflow.Context, args *schedulespb.StartScheduleArgs) error { + s.env().SetStartTime(baseStartTime) + s.env().ExecuteWorkflow(func(ctx workflow.Context, args *schedulespb.StartScheduleArgs) error { return schedulerWorkflowWithSpecBuilder(ctx, args, NewSpecBuilder(), func() bool { return true }, func() bool { return true }) }, &schedulespb.StartScheduleArgs{ Schedule: &schedulepb.Schedule{ @@ -2515,8 +2533,8 @@ func (s *workflowSuite) TestMigrateDynamicConfig() { }) // Workflow should complete successfully (not CAN) after migration triggered by tweakable. - s.True(s.env.IsWorkflowCompleted()) - s.NoError(s.env.GetWorkflowError()) + s.True(s.env().IsWorkflowCompleted()) + s.NoError(s.env().GetWorkflowError()) } // TestMigrateDynamicConfigFlipsMidRun verifies that the enableCHASMMigration @@ -2526,7 +2544,7 @@ func (s *workflowSuite) TestMigrateDynamicConfig() { func (s *workflowSuite) TestMigrateDynamicConfigFlipsMidRun() { enabled := false migrateCalls := 0 - s.env.OnActivity(new(activities).MigrateScheduleToChasm, mock.Anything, mock.Anything).Return( + s.env().OnActivity(new(activities).MigrateScheduleToChasm, mock.Anything, mock.Anything).Return( func(context.Context, *schedulerpb.CreateFromMigrationStateRequest) error { migrateCalls++ return nil @@ -2539,12 +2557,12 @@ func (s *workflowSuite) TestMigrateDynamicConfigFlipsMidRun() { defer func() { CurrentTweakablePolicies = prevTweakables }() // Flip the closure between iteration 1 and iteration 2 (1h interval). - s.env.RegisterDelayedCallback(func() { + s.env().RegisterDelayedCallback(func() { enabled = true }, 30*time.Minute) - s.env.SetStartTime(baseStartTime) - s.env.ExecuteWorkflow(func(ctx workflow.Context, args *schedulespb.StartScheduleArgs) error { + s.env().SetStartTime(baseStartTime) + s.env().ExecuteWorkflow(func(ctx workflow.Context, args *schedulespb.StartScheduleArgs) error { return schedulerWorkflowWithSpecBuilder(ctx, args, NewSpecBuilder(), func() bool { return enabled }, func() bool { return true }) }, &schedulespb.StartScheduleArgs{ Schedule: &schedulepb.Schedule{ @@ -2563,8 +2581,8 @@ func (s *workflowSuite) TestMigrateDynamicConfigFlipsMidRun() { }, }) - s.True(s.env.IsWorkflowCompleted()) - s.Require().NoError(s.env.GetWorkflowError(), "workflow should complete after the dynamic flip triggers migration") + s.True(s.env().IsWorkflowCompleted()) + s.NoError(s.env().GetWorkflowError(), "workflow should complete after the dynamic flip triggers migration") s.Equal(1, migrateCalls, "migration should fire exactly once, after the DC flips") } @@ -2572,7 +2590,7 @@ func (s *workflowSuite) TestMigrateDynamicConfigFailure() { // Enable migration by threading enableCHASMMigration=true through the closure (race-safe), // but activity fails. migrateCalls := 0 - s.env.OnActivity(new(activities).MigrateScheduleToChasm, mock.Anything, mock.Anything).Return( + s.env().OnActivity(new(activities).MigrateScheduleToChasm, mock.Anything, mock.Anything).Return( func(context.Context, *schedulerpb.CreateFromMigrationStateRequest) error { migrateCalls++ return errors.New("migration failed") @@ -2582,8 +2600,8 @@ func (s *workflowSuite) TestMigrateDynamicConfigFailure() { CurrentTweakablePolicies.IterationsBeforeContinueAsNew = 5 defer func() { CurrentTweakablePolicies = prevTweakables }() - s.env.SetStartTime(baseStartTime) - s.env.ExecuteWorkflow(func(ctx workflow.Context, args *schedulespb.StartScheduleArgs) error { + s.env().SetStartTime(baseStartTime) + s.env().ExecuteWorkflow(func(ctx workflow.Context, args *schedulespb.StartScheduleArgs) error { return schedulerWorkflowWithSpecBuilder(ctx, args, NewSpecBuilder(), func() bool { return true }, func() bool { return true }) }, &schedulespb.StartScheduleArgs{ Schedule: &schedulepb.Schedule{ @@ -2603,16 +2621,16 @@ func (s *workflowSuite) TestMigrateDynamicConfigFailure() { }) // Workflow should CAN after all iterations, not terminate. - s.True(s.env.IsWorkflowCompleted()) - s.True(workflow.IsContinueAsNewError(s.env.GetWorkflowError())) + s.True(s.env().IsWorkflowCompleted()) + s.True(workflow.IsContinueAsNewError(s.env().GetWorkflowError())) // Migration attempted every iteration. s.Equal(5, migrateCalls) // PendingMigration should be preserved in CAN state. var canErr *workflow.ContinueAsNewError - s.Require().ErrorAs(s.env.GetWorkflowError(), &canErr) + s.ErrorAs(s.env().GetWorkflowError(), &canErr) var canArgs schedulespb.StartScheduleArgs - s.Require().NoError(payloads.Decode(canErr.Input, &canArgs)) + s.NoError(payloads.Decode(canErr.Input, &canArgs)) s.True(canArgs.State.PendingMigration, "PendingMigration should be set in CAN state") } @@ -2625,8 +2643,8 @@ func (s *workflowSuite) TestMigrateDynamicConfigDisabledNoMigration() { // No activity mock registered -- if migration is attempted, the test will fail. CurrentTweakablePolicies.IterationsBeforeContinueAsNew = 3 - s.env.SetStartTime(baseStartTime) - s.env.ExecuteWorkflow(SchedulerWorkflow, &schedulespb.StartScheduleArgs{ + s.env().SetStartTime(baseStartTime) + s.env().ExecuteWorkflow(SchedulerWorkflow, &schedulespb.StartScheduleArgs{ Schedule: &schedulepb.Schedule{ Spec: &schedulepb.ScheduleSpec{ Interval: []*schedulepb.IntervalSpec{{ @@ -2644,6 +2662,6 @@ func (s *workflowSuite) TestMigrateDynamicConfigDisabledNoMigration() { }) // Workflow should CAN normally without attempting migration. - s.True(s.env.IsWorkflowCompleted()) - s.True(workflow.IsContinueAsNewError(s.env.GetWorkflowError())) + s.True(s.env().IsWorkflowCompleted()) + s.True(workflow.IsContinueAsNewError(s.env().GetWorkflowError())) }