From 0bf0200b817610f92d62330f935ee5cc712ab7ae Mon Sep 17 00:00:00 2001 From: Stephan Behnke Date: Wed, 27 May 2026 14:44:18 -0700 Subject: [PATCH 01/16] Replace onebox.go --- common/persistence/client/fx.go | 23 +- common/resource/fx.go | 26 +- common/testing/testhooks/hooks.go | 21 + service/frontend/fx.go | 5 + service/fx.go | 6 + service/history/fx.go | 21 + temporal/fx.go | 35 +- temporal/server_option.go | 7 + temporal/server_options.go | 2 + tests/archival_test.go | 151 +++- tests/dlq_test.go | 33 +- tests/testcore/clients.go | 3 + tests/testcore/functional_test_base.go | 100 +-- tests/testcore/functional_test_base_test.go | 27 +- tests/testcore/onebox.go | 877 +++++++------------- tests/testcore/test_cluster.go | 251 ++---- tests/testcore/test_cluster_pool_test.go | 5 +- tests/testcore/test_env.go | 54 +- tests/tls_test.go | 2 +- tests/xdc/base.go | 1 - 20 files changed, 697 insertions(+), 953 deletions(-) diff --git a/common/persistence/client/fx.go b/common/persistence/client/fx.go index 8cfbb4f8029..3cc9b9228ce 100644 --- a/common/persistence/client/fx.go +++ b/common/persistence/client/fx.go @@ -21,6 +21,7 @@ import ( "go.temporal.io/server/common/quotas" "go.temporal.io/server/common/resolver" otel "go.temporal.io/server/common/telemetry" + "go.temporal.io/server/common/testing/testhooks" "go.uber.org/fx" ) @@ -217,9 +218,18 @@ func DataStoreFactoryLifetimeHooks(lc fx.Lifecycle, f persistence.DataStoreFacto lc.Append(fx.StopHook(f.Close)) } -func managerProvider[T persistence.Closeable](newManagerFn func(Factory) (T, error)) func(Factory, fx.Lifecycle) (T, error) { - return func(f Factory, lc fx.Lifecycle) (T, error) { - manager, err := newManagerFn(f) // passing receiver (Factory) as first argument. 
+type managerProviderParams struct { + fx.In + + Factory Factory + Lifecycle fx.Lifecycle + TestHooks testhooks.TestHooks `optional:"true"` + Logger log.Logger +} + +func managerProvider[T persistence.Closeable](newManagerFn func(Factory) (T, error)) func(managerProviderParams) (T, error) { + return func(params managerProviderParams) (T, error) { + manager, err := newManagerFn(params.Factory) // passing receiver (Factory) as first argument. if err != nil { var unimpl *serviceerror.Unimplemented if errors.As(err, &unimpl) { @@ -229,7 +239,12 @@ func managerProvider[T persistence.Closeable](newManagerFn func(Factory) (T, err var nilT T return nilT, err } - lc.Append(fx.StopHook(manager.Close)) + if executionManager, ok := any(manager).(persistence.ExecutionManager); ok { + if hook, ok := testhooks.Get(params.TestHooks, testhooks.PersistenceExecutionManagerWrapper, testhooks.GlobalScope); ok { + manager = any(hook(executionManager, params.Logger)).(T) + } + } + params.Lifecycle.Append(fx.StopHook(manager.Close)) return manager, nil } } diff --git a/common/resource/fx.go b/common/resource/fx.go index 29e9a756a11..d7a0a929bab 100644 --- a/common/resource/fx.go +++ b/common/resource/fx.go @@ -225,6 +225,7 @@ func SearchAttributeValidatorProvider( type NamespaceRegistryParams struct { fx.In + ServiceName primitives.ServiceName `optional:"true"` Logger log.SnTaggedLogger MetricsHandler metrics.Handler ClusterMetadata cluster.Metadata @@ -232,10 +233,11 @@ type NamespaceRegistryParams struct { DynamicCollection *dynamicconfig.Collection ReplicationResolverFactory namespace.ReplicationResolverFactory NamespaceStateChangedFn namespace.NamespaceStateChangedFn + TestHooks testhooks.TestHooks `optional:"true"` } func NamespaceRegistryProvider(params NamespaceRegistryParams) namespace.Registry { - return nsregistry.NewRegistry( + registry := nsregistry.NewRegistry( params.MetadataManager, params.ClusterMetadata.IsGlobalNamespaceEnabled(), 
params.ClusterMetadata.GetCurrentClusterName(), @@ -246,6 +248,10 @@ func NamespaceRegistryProvider(params NamespaceRegistryParams) namespace.Registr params.ReplicationResolverFactory, params.NamespaceStateChangedFn, ) + if hook, ok := testhooks.Get(params.TestHooks, testhooks.NamespaceRegistryCreated, testhooks.GlobalScope); ok { + hook(params.ServiceName, registry) + } + return registry } func ClientFactoryProvider( @@ -326,10 +332,19 @@ func HistoryClientProvider(historyRawClient HistoryRawClient) HistoryClient { } func MatchingRawClientProvider( + serviceName primitives.ServiceName, clientBean client.Bean, namespaceRegistry namespace.Registry, + testHooks testhooks.TestHooks, ) (MatchingRawClient, error) { - return clientBean.GetMatchingClient(namespaceRegistry.GetNamespaceName) + client, err := clientBean.GetMatchingClient(namespaceRegistry.GetNamespaceName) + if err != nil { + return nil, err + } + if hook, ok := testhooks.Get(testHooks, testhooks.MatchingRawClientCreated, testhooks.GlobalScope); ok { + hook(serviceName, client) + } + return client, nil } func MatchingClientProvider(matchingRawClient MatchingRawClient) MatchingClient { @@ -403,13 +418,18 @@ func DCRedirectionPolicyProvider(cfg *config.Config) config.DCRedirectionPolicy func PerServiceDialOptionsProvider( logger log.SnTaggedLogger, + testHooks testhooks.TestHooks, ) map[primitives.ServiceName][]grpc.DialOption { trailerInterceptor := interceptor.TrailerToContextMetadataInterceptor(logger) dialOpt := grpc.WithChainUnaryInterceptor(trailerInterceptor) - return map[primitives.ServiceName][]grpc.DialOption{ + options := map[primitives.ServiceName][]grpc.DialOption{ primitives.HistoryService: {dialOpt}, primitives.MatchingService: {dialOpt}, } + if hook, ok := testhooks.Get(testHooks, testhooks.ServiceClientDialOptions, testhooks.GlobalScope); ok { + hook(options) + } + return options } func RPCFactoryProvider( diff --git a/common/testing/testhooks/hooks.go b/common/testing/testhooks/hooks.go index 
349dbc1e1ed..658ad139077 100644 --- a/common/testing/testhooks/hooks.go +++ b/common/testing/testhooks/hooks.go @@ -5,10 +5,24 @@ import ( "time" "go.temporal.io/server/api/historyservice/v1" + "go.temporal.io/server/api/matchingservice/v1" persistencespb "go.temporal.io/server/api/persistence/v1" replicationspb "go.temporal.io/server/api/replication/v1" + "go.temporal.io/server/chasm" + "go.temporal.io/server/common/log" "go.temporal.io/server/common/namespace" + "go.temporal.io/server/common/persistence" + "go.temporal.io/server/common/primitives" historytasks "go.temporal.io/server/service/history/tasks" + "google.golang.org/grpc" +) + +type ( + HistoryChasmComponents struct { + Engine chasm.Engine + VisibilityManager chasm.VisibilityManager + Registry *chasm.Registry + } ) // Test hook keys with their return type and scope. @@ -28,6 +42,13 @@ var ( HistoryTransferTaskInterceptor = newKey[func(historytasks.Task, func()), namespace.ID]() HistoryDLQTaskDeleteInterceptor = newKey[func(context.Context, *historyservice.DeleteDLQTasksRequest, func(context.Context, *historyservice.DeleteDLQTasksRequest) (*historyservice.DeleteDLQTasksResponse, error)) (*historyservice.DeleteDLQTasksResponse, error), global]() NamespaceReplicationTaskInterceptor = newKey[func(context.Context, *replicationspb.NamespaceTaskAttributes, func() error) error, namespace.Name]() + ServiceGrpcInterceptors = newKey[func(primitives.ServiceName, *[]grpc.UnaryServerInterceptor, *[]grpc.StreamServerInterceptor), global]() + ServiceClientDialOptions = newKey[func(map[primitives.ServiceName][]grpc.DialOption), global]() + NamespaceRegistryCreated = newKey[func(primitives.ServiceName, namespace.Registry), global]() + MatchingRawClientCreated = newKey[func(primitives.ServiceName, matchingservice.MatchingServiceClient), global]() + ChasmRegistryInitializer = newKey[func(*chasm.Registry) error, global]() + HistoryChasmComponentsCreated = newKey[func(HistoryChasmComponents), global]() + 
PersistenceExecutionManagerWrapper = newKey[func(persistence.ExecutionManager, log.Logger) persistence.ExecutionManager, global]() ) // keyID is a unique identifier for a key, used as a map key. diff --git a/service/frontend/fx.go b/service/frontend/fx.go index 53b7094ceb9..c932713dda9 100644 --- a/service/frontend/fx.go +++ b/service/frontend/fx.go @@ -246,7 +246,12 @@ func GrpcServerOptionsProvider( customInterceptors []grpc.UnaryServerInterceptor, customStreamInterceptors []grpc.StreamServerInterceptor, metricsHandler metrics.Handler, + testHooks testhooks.TestHooks, ) GrpcServerOptions { + if hook, ok := testhooks.Get(testHooks, testhooks.ServiceGrpcInterceptors, testhooks.GlobalScope); ok { + hook(serviceName, &customInterceptors, &customStreamInterceptors) + } + kep := keepalive.EnforcementPolicy{ MinTime: serviceConfig.KeepAliveMinTime(), PermitWithoutStream: serviceConfig.KeepAlivePermitWithoutStream(), diff --git a/service/fx.go b/service/fx.go index 5368c261ffb..2da1160c9a9 100644 --- a/service/fx.go +++ b/service/fx.go @@ -15,6 +15,7 @@ import ( "go.temporal.io/server/common/rpc" "go.temporal.io/server/common/rpc/interceptor" "go.temporal.io/server/common/telemetry" + "go.temporal.io/server/common/testing/testhooks" "go.uber.org/fx" "google.golang.org/grpc" ) @@ -39,6 +40,7 @@ type ( fx.In Logger log.Logger + ServiceName primitives.ServiceName RPCFactory common.RPCFactory ServiceErrorInterceptor *interceptor.ServiceErrorInterceptor RetryableInterceptor *interceptor.RetryableInterceptor @@ -50,6 +52,7 @@ type ( ContextMetadataInterceptor *interceptor.ContextMetadataInterceptor `optional:"true"` AdditionalInterceptors []grpc.UnaryServerInterceptor `optional:"true"` AdditionalStreamInterceptors []grpc.StreamServerInterceptor `optional:"true"` + TestHooks testhooks.TestHooks } ) @@ -124,6 +127,9 @@ func NewPersistenceRateLimitingParams( func GrpcServerOptionsProvider( params GrpcServerOptionsParams, ) []grpc.ServerOption { + if hook, ok := 
testhooks.Get(params.TestHooks, testhooks.ServiceGrpcInterceptors, testhooks.GlobalScope); ok { + hook(params.ServiceName, &params.AdditionalInterceptors, &params.AdditionalStreamInterceptors) + } grpcServerOptions, err := params.RPCFactory.GetInternodeGRPCServerOptions() if err != nil { diff --git a/service/history/fx.go b/service/history/fx.go index 024f9826aee..3a44a4a4c23 100644 --- a/service/history/fx.go +++ b/service/history/fx.go @@ -34,6 +34,7 @@ import ( "go.temporal.io/server/common/rpc/interceptor" "go.temporal.io/server/common/searchattribute" "go.temporal.io/server/common/tasktoken" + "go.temporal.io/server/common/testing/testhooks" "go.temporal.io/server/common/worker_versioning" "go.temporal.io/server/components/callbacks" hsmnexusoperations "go.temporal.io/server/components/nexusoperations" @@ -99,6 +100,7 @@ var Module = fx.Options( workerdeployment.ClientModule, fx.Provide(RoutingInfoCacheProvider), fx.Invoke(ServiceLifetimeHooks), + fx.Invoke(ServiceRefsHook), callbacks.Module, hsmnexusoperations.Module, @@ -432,6 +434,25 @@ func ChasmVisibilityManagerProvider( ) } +type serviceRefsHookParams struct { + fx.In + + ChasmEngine chasm.Engine + ChasmVisibilityManager chasm.VisibilityManager + ChasmRegistry *chasm.Registry + TestHooks testhooks.TestHooks +} + +func ServiceRefsHook(params serviceRefsHookParams) { + if hook, ok := testhooks.Get(params.TestHooks, testhooks.HistoryChasmComponentsCreated, testhooks.GlobalScope); ok { + hook(testhooks.HistoryChasmComponents{ + Engine: params.ChasmEngine, + VisibilityManager: params.ChasmVisibilityManager, + Registry: params.ChasmRegistry, + }) + } +} + func EventNotifierProvider( timeSource clock.TimeSource, metricsHandler metrics.Handler, diff --git a/temporal/fx.go b/temporal/fx.go index 40a3b37b096..ac9a6490489 100644 --- a/temporal/fx.go +++ b/temporal/fx.go @@ -50,6 +50,7 @@ import ( "go.temporal.io/server/common/searchattribute" "go.temporal.io/server/common/searchattribute/sadefs" 
"go.temporal.io/server/common/telemetry" + "go.temporal.io/server/common/testing/testhooks" "go.temporal.io/server/service/frontend" "go.temporal.io/server/service/history" "go.temporal.io/server/service/history/replication" @@ -125,6 +126,7 @@ type ( TLSConfigProvider encryption.TLSConfigProvider EsClient esclient.Client MetricsHandler metrics.Handler + TestHooks testhooks.TestHooks } ) @@ -220,6 +222,11 @@ func ServerOptionsProvider(opts []ServerOption) (serverOptionsProvider, error) { } } + testHooks := so.testHooks + if testHooks == (testhooks.TestHooks{}) { + testHooks = testhooks.NewTestHooks() + } + // TLSConfigProvider tlsConfigProvider := so.tlsConfigProvider if tlsConfigProvider == nil { @@ -261,10 +268,10 @@ func ServerOptionsProvider(opts []ServerOption) (serverOptionsProvider, error) { } } - // check that when static hosts are defined, they are defined for all required hosts + // check that when static hosts are defined, they are defined for all requested hosts if len(so.hostsByService) > 0 { - for _, service := range DefaultServices { - hosts := so.hostsByService[primitives.ServiceName(service)] + for service := range so.serviceNames { + hosts := so.hostsByService[service] if len(hosts.All) == 0 { return serverOptionsProvider{}, fmt.Errorf("%w: %v", missingServiceInStaticHosts, service) } @@ -314,6 +321,7 @@ func ServerOptionsProvider(opts []ServerOption) (serverOptionsProvider, error) { TLSConfigProvider: tlsConfigProvider, EsClient: esClient, MetricsHandler: metricHandler, + TestHooks: testHooks, }, nil } @@ -380,6 +388,7 @@ type ( InstanceID resource.InstanceID `optional:"true"` StaticServiceHosts map[primitives.ServiceName]static.Hosts `optional:"true"` TaskCategoryRegistry tasks.TaskCategoryRegistry + TestHooks testhooks.TestHooks } ) @@ -462,14 +471,32 @@ func (params ServiceProviderParamsCommon) GetCommonServiceOptions(serviceName pr return params.TaskCategoryRegistry }, ), + fx.Decorate(func() testhooks.TestHooks { + return params.TestHooks + 
}), ServiceTracingModule, resource.DefaultOptions, membershipModule, FxLogAdapter, chasm.Module, + fx.Invoke(ChasmRegistryInitializerHook), ) } +type chasmRegistryInitializerHookParams struct { + fx.In + + Registry *chasm.Registry + TestHooks testhooks.TestHooks +} + +func ChasmRegistryInitializerHook(params chasmRegistryInitializerHookParams) error { + if hook, ok := testhooks.Get(params.TestHooks, testhooks.ChasmRegistryInitializer, testhooks.GlobalScope); ok { + return hook(params.Registry) + } + return nil +} + // TaskCategoryRegistryProvider provides an immutable tasks.TaskCategoryRegistry to the server, which is intended to be // shared by each service. Why do we need to initialize this at the top-level? Because, even though the presence of the // archival task category is only needed by the history service, which must conditionally start a queue processor for @@ -557,7 +584,7 @@ func genericFrontendServiceProvider( app := fx.New( params.GetCommonServiceOptions(serviceName), fx.Supply(params.CustomFrontendInterceptors), - fx.Supply([]grpc.StreamServerInterceptor{}), + fx.Supply([]grpc.StreamServerInterceptor(nil)), fx.Decorate(func() authorization.ClaimMapper { switch serviceName { case primitives.FrontendService: diff --git a/temporal/server_option.go b/temporal/server_option.go index 54ffda44b48..11d9d5a2c08 100644 --- a/temporal/server_option.go +++ b/temporal/server_option.go @@ -18,6 +18,7 @@ import ( "go.temporal.io/server/common/rpc/auth" "go.temporal.io/server/common/rpc/encryption" "go.temporal.io/server/common/searchattribute" + "go.temporal.io/server/common/testing/testhooks" "google.golang.org/grpc" ) @@ -206,6 +207,12 @@ func WithTokenProvider(tp auth.TokenProvider) ServerOption { }) } +func WithTestHooks(testHooks testhooks.TestHooks) ServerOption { + return applyFunc(func(s *serverOptions) { + s.testHooks = testHooks + }) +} + // WithCustomerMetricsProvider sets a custom implementation of the metrics.MetricsHandler interface // 
metrics.MetricsHandler is the base interface for publishing metric events func WithCustomMetricsHandler(provider metrics.Handler) ServerOption { diff --git a/temporal/server_options.go b/temporal/server_options.go index 1ac2727f41b..96fa6d8ed3f 100644 --- a/temporal/server_options.go +++ b/temporal/server_options.go @@ -21,6 +21,7 @@ import ( "go.temporal.io/server/common/rpc/auth" "go.temporal.io/server/common/rpc/encryption" "go.temporal.io/server/common/searchattribute" + "go.temporal.io/server/common/testing/testhooks" "google.golang.org/grpc" ) @@ -60,6 +61,7 @@ type ( customFrontendInterceptors []grpc.UnaryServerInterceptor metricHandler metrics.Handler tokenProvider auth.TokenProvider + testHooks testhooks.TestHooks } ) diff --git a/tests/archival_test.go b/tests/archival_test.go index e2debb413a2..56ff367dffe 100644 --- a/tests/archival_test.go +++ b/tests/archival_test.go @@ -5,6 +5,7 @@ import ( "context" "encoding/binary" "fmt" + "os" "strconv" "sync/atomic" "testing" @@ -23,16 +24,21 @@ import ( "go.temporal.io/server/chasm" "go.temporal.io/server/common" "go.temporal.io/server/common/archiver" + "go.temporal.io/server/common/archiver/filestore" "go.temporal.io/server/common/archiver/provider" + "go.temporal.io/server/common/config" "go.temporal.io/server/common/convert" "go.temporal.io/server/common/dynamicconfig" + "go.temporal.io/server/common/log" "go.temporal.io/server/common/log/tag" + "go.temporal.io/server/common/metrics" "go.temporal.io/server/common/namespace" "go.temporal.io/server/common/payloads" "go.temporal.io/server/common/persistence" "go.temporal.io/server/common/persistence/versionhistory" "go.temporal.io/server/common/searchattribute" "go.temporal.io/server/common/testing/protoassert" + "go.temporal.io/server/temporal" "go.temporal.io/server/tests/testcore" "google.golang.org/protobuf/types/known/durationpb" ) @@ -56,6 +62,19 @@ type ( // Counters to verify custom archivers are being called customHistoryArchiveCalled atomic.Int32 
customVisibilityArchiveCalled atomic.Int32 + + archiverBase *archivalArchiverBase + } + + archivalArchiverBase struct { + metadata archiver.ArchivalMetadata + provider provider.ArchiverProvider + historyProvider *config.HistoryArchiverProvider + visibilityProvider *config.VisibilityArchiverProvider + historyStoreDirectory string + visibilityStoreDirectory string + historyURI string + visibilityURI string } archivalWorkflowInfo struct { @@ -74,6 +93,94 @@ type ( } ) +func newArchivalArchiverBase(t *testing.T) *archivalArchiverBase { + t.Helper() + + historyStoreDirectory, err := os.MkdirTemp("", "test-history-archival") + if err != nil { + t.Fatal(err) + } + visibilityStoreDirectory, err := os.MkdirTemp("", "test-visibility-archival") + if err != nil { + t.Fatal(err) + } + cfg := &config.FilestoreArchiver{ + FileMode: "0666", + DirMode: "0766", + } + historyProvider := &config.HistoryArchiverProvider{ + Filestore: cfg, + } + visibilityProvider := &config.VisibilityArchiverProvider{ + Filestore: cfg, + } + historyURI := filestore.URIScheme + "://" + historyStoreDirectory + visibilityURI := filestore.URIScheme + "://" + visibilityStoreDirectory + return &archivalArchiverBase{ + metadata: archiver.NewArchivalMetadata(dynamicconfig.NewNoopCollection(), "enabled", true, "enabled", true, &config.ArchivalNamespaceDefaults{ + History: config.HistoryArchivalNamespaceDefaults{ + State: "enabled", + URI: historyURI, + }, + Visibility: config.VisibilityArchivalNamespaceDefaults{ + State: "enabled", + URI: visibilityURI, + }, + }), + historyProvider: historyProvider, + visibilityProvider: visibilityProvider, + historyStoreDirectory: historyStoreDirectory, + visibilityStoreDirectory: visibilityStoreDirectory, + historyURI: historyURI, + visibilityURI: visibilityURI, + } +} + +func (a *archivalArchiverBase) applyConfig(cfg *config.Config) { + cfg.Archival.History = config.HistoryArchival{ + State: config.ArchivalEnabled, + EnableRead: true, + Provider: a.historyProvider, + } + 
cfg.Archival.Visibility = config.VisibilityArchival{ + State: config.ArchivalEnabled, + EnableRead: true, + Provider: a.visibilityProvider, + } + cfg.NamespaceDefaults.Archival.History = config.HistoryArchivalNamespaceDefaults{ + State: config.ArchivalEnabled, + URI: a.historyURI, + } + cfg.NamespaceDefaults.Archival.Visibility = config.VisibilityArchivalNamespaceDefaults{ + State: config.ArchivalEnabled, + URI: a.visibilityURI, + } +} + +func (a *archivalArchiverBase) initProvider( + executionManager persistence.ExecutionManager, + customHistoryArchiverFactory provider.CustomHistoryArchiverFactory, + customVisibilityArchiverFactory provider.CustomVisibilityArchiverFactory, +) { + a.provider = provider.NewArchiverProvider( + a.historyProvider, + a.visibilityProvider, + customHistoryArchiverFactory, + customVisibilityArchiverFactory, + executionManager, + log.NewNoopLogger(), + metrics.NoopMetricsHandler, + ) +} + +func (a *archivalArchiverBase) tearDown() error { + err := os.RemoveAll(a.historyStoreDirectory) + if visibilityErr := os.RemoveAll(a.visibilityStoreDirectory); visibilityErr != nil && err == nil { + err = visibilityErr + } + return err +} + // customHistoryArchiver method implementations func (c *customHistoryArchiver) Archive(ctx context.Context, uri archiver.URI, request *archiver.ArchiveHistoryRequest, opts ...archiver.ArchiveOption) error { c.counter.Add(1) @@ -103,11 +210,13 @@ func (c *customVisibilityArchiver) ValidateURI(uri archiver.URI) error { } func TestArchivalSuite(t *testing.T) { - t.Parallel() // This suite can work in parallel as long as it is the only one that use testcore.WithArchivalEnabled() option. 
+ t.Parallel() suite.Run(t, new(ArchivalSuite)) } func (s *ArchivalSuite) SetupSuite() { + s.archiverBase = newArchivalArchiverBase(s.T()) + dynamicConfigOverrides := map[dynamicconfig.Key]any{ dynamicconfig.ArchivalProcessorArchiveDelay.Key(): time.Duration(0), } @@ -142,9 +251,16 @@ func (s *ArchivalSuite) SetupSuite() { s.FunctionalTestBase.SetupSuiteWithCluster( testcore.WithDynamicConfigOverrides(dynamicConfigOverrides), - testcore.WithArchivalEnabled(), - testcore.WithCustomHistoryArchiverFactory(customHistoryArchiverFactory), - testcore.WithCustomVisibilityArchiverFactory(customVisibilityArchiverFactory), + testcore.WithServerConfigOverride(s.archiverBase.applyConfig), + testcore.WithServerOptions( + temporal.WithCustomHistoryArchiverFactory(customHistoryArchiverFactory), + temporal.WithCustomVisibilityArchiverFactory(customVisibilityArchiverFactory), + ), + ) + s.archiverBase.initProvider( + s.GetTestCluster().ExecutionManager(), + customHistoryArchiverFactory, + customVisibilityArchiverFactory, ) var err error @@ -155,8 +271,8 @@ func (s *ArchivalSuite) SetupSuite() { s.archivalNamespace, 0, // Archive right away. 
enumspb.ARCHIVAL_STATE_ENABLED, - s.GetTestCluster().ArchiverBase().HistoryURI(), - s.GetTestCluster().ArchiverBase().VisibilityURI(), + s.archiverBase.historyURI, + s.archiverBase.visibilityURI, ) s.Require().NoError(err) @@ -175,13 +291,18 @@ func (s *ArchivalSuite) SetupSuite() { } func (s *ArchivalSuite) TearDownSuite() { + defer func() { + if s.archiverBase != nil { + s.Require().NoError(s.archiverBase.tearDown()) + } + }() s.Require().NoError(s.MarkNamespaceAsDeleted(s.archivalNamespace)) s.Require().NoError(s.MarkNamespaceAsDeleted(s.customArchiverNamespace)) s.FunctionalTestBase.TearDownCluster() } func (s *ArchivalSuite) TestArchival_TimerQueueProcessor() { - s.True(s.GetTestCluster().ArchiverBase().Metadata().GetHistoryConfig().ClusterConfiguredForArchival()) + s.True(s.archiverBase.metadata.GetHistoryConfig().ClusterConfiguredForArchival()) workflowID := "archival-timer-queue-processor-workflow-id" workflowType := "archival-timer-queue-processor-type" @@ -196,7 +317,7 @@ func (s *ArchivalSuite) TestArchival_TimerQueueProcessor() { } func (s *ArchivalSuite) TestArchival_ContinueAsNew() { - s.True(s.GetTestCluster().ArchiverBase().Metadata().GetHistoryConfig().ClusterConfiguredForArchival()) + s.True(s.archiverBase.metadata.GetHistoryConfig().ClusterConfiguredForArchival()) workflowID := "archival-continueAsNew-workflow-id" workflowType := "archival-continueAsNew-workflow-type" @@ -215,7 +336,7 @@ func (s *ArchivalSuite) TestArchival_ContinueAsNew() { func (s *ArchivalSuite) TestArchival_ArchiverWorker() { // s.T().SkipNow() // flaky test, skip for now, will reimplement archival feature. 
- s.True(s.GetTestCluster().ArchiverBase().Metadata().GetHistoryConfig().ClusterConfiguredForArchival()) + s.True(s.archiverBase.metadata.GetHistoryConfig().ClusterConfiguredForArchival()) workflowID := "archival-archiver-worker-workflow-id" workflowType := "archival-archiver-worker-workflow-type" @@ -229,7 +350,7 @@ func (s *ArchivalSuite) TestArchival_ArchiverWorker() { } func (s *ArchivalSuite) TestVisibilityArchival() { - s.True(s.GetTestCluster().ArchiverBase().Metadata().GetVisibilityConfig().ClusterConfiguredForArchival()) + s.True(s.archiverBase.metadata.GetVisibilityConfig().ClusterConfiguredForArchival()) workflowID := "archival-visibility-workflow-id" workflowType := "archival-visibility-workflow-type" @@ -277,7 +398,7 @@ func (s *ArchivalSuite) TestVisibilityArchival() { } func (s *ArchivalSuite) TestCustomArchiver() { - s.True(s.GetTestCluster().ArchiverBase().Metadata().GetHistoryConfig().ClusterConfiguredForArchival()) + s.True(s.archiverBase.metadata.GetHistoryConfig().ClusterConfiguredForArchival()) workflowID := "custom-history-archiver-workflow-id" workflowType := "custom-history-archiver-type" @@ -305,16 +426,16 @@ func (s *ArchivalSuite) TestCustomArchiver() { // workflowIsArchived asserts that both the workflow history and workflow visibility are archived. 
func (s *ArchivalSuite) workflowIsArchived(namespaceID namespace.ID, execution *commonpb.WorkflowExecution) { - historyURI, err := archiver.NewURI(s.GetTestCluster().ArchiverBase().HistoryURI()) + historyURI, err := archiver.NewURI(s.archiverBase.historyURI) s.NoError(err) - historyArchiver, err := s.GetTestCluster().ArchiverBase().Provider().GetHistoryArchiver( + historyArchiver, err := s.archiverBase.provider.GetHistoryArchiver( historyURI.Scheme(), ) s.NoError(err) - visibilityURI, err := archiver.NewURI(s.GetTestCluster().ArchiverBase().VisibilityURI()) + visibilityURI, err := archiver.NewURI(s.archiverBase.visibilityURI) s.NoError(err) - visibilityArchiver, err := s.GetTestCluster().ArchiverBase().Provider().GetVisibilityArchiver( + visibilityArchiver, err := s.archiverBase.provider.GetVisibilityArchiver( visibilityURI.Scheme(), ) s.NoError(err) diff --git a/tests/dlq_test.go b/tests/dlq_test.go index 90b8410ff44..5bff146e797 100644 --- a/tests/dlq_test.go +++ b/tests/dlq_test.go @@ -29,7 +29,6 @@ import ( "go.temporal.io/server/common/persistence" "go.temporal.io/server/common/persistence/serialization" "go.temporal.io/server/common/primitives" - "go.temporal.io/server/common/sdk" "go.temporal.io/server/common/testing/await" "go.temporal.io/server/common/testing/parallelsuite" "go.temporal.io/server/common/testing/testhooks" @@ -38,7 +37,6 @@ import ( "go.temporal.io/server/tests/testutils" "go.temporal.io/server/tools/tdbg" "go.temporal.io/server/tools/tdbg/tdbgtest" - "go.uber.org/fx" "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/proto" ) @@ -50,10 +48,10 @@ type ( dlqTestEnv struct { *testcore.TestEnv - dlq persistence.HistoryTaskQueueManager - writer bytes.Buffer - sdkClientFactory sdk.ClientFactory - deleteBlockCh chan any + dlq persistence.HistoryTaskQueueManager + writer bytes.Buffer + systemSDKClient sdkclient.Client + deleteBlockCh chan any failingWorkflowIDPrefix atomic.Pointer[string] } @@ -93,16 +91,21 @@ func (s 
*DLQSuite) newTestEnv(opts ...testcore.TestOption) *dlqTestEnv { return serialization.NewDeserializationError(enumspb.ENCODING_TYPE_PROTO3, errors.New("test error")) }, }), - testcore.WithFxOptions(primitives.HistoryService, - fx.Populate(&w.dlq), - ), - testcore.WithFxOptions(primitives.FrontendService, - fx.Populate(&w.sdkClientFactory), - ), } w.TestEnv = testcore.NewEnv(s.T(), append(baseOpts, opts...)...) w.SdkWorker().RegisterWorkflow(s.myWorkflow) + var err error + w.dlq, err = w.GetTestCluster().TestBase().Factory.NewHistoryTaskQueueManager() + s.NoError(err) + s.T().Cleanup(w.dlq.Close) + w.systemSDKClient, err = sdkclient.Dial(sdkclient.Options{ + HostPort: w.FrontendGRPCAddress(), + Namespace: primitives.SystemLocalNamespace, + }) + s.NoError(err) + s.T().Cleanup(w.systemSDKClient.Close) + w.deleteBlockCh = make(chan any) close(w.deleteBlockCh) @@ -460,8 +463,7 @@ func (s *DLQSuite) purgeMessages(env *dlqTestEnv, maxMessageIDToDelete int64) st var token adminservice.DLQJobToken s.NoError(proto.Unmarshal(response.GetJobToken(), &token)) - systemSDKClient := env.sdkClientFactory.GetSystemClient() - run := systemSDKClient.GetWorkflow(s.Context(), token.WorkflowId, token.RunId) + run := env.systemSDKClient.GetWorkflow(s.Context(), token.WorkflowId, token.RunId) s.NoError(run.Get(s.Context(), nil)) return tokenString } @@ -473,8 +475,7 @@ func (s *DLQSuite) mergeMessages(env *dlqTestEnv, maxMessageID int64) string { s.NoError(err) var token adminservice.DLQJobToken s.NoError(token.Unmarshal(tokenBytes)) - systemSDKClient := env.sdkClientFactory.GetSystemClient() - run := systemSDKClient.GetWorkflow(s.Context(), token.WorkflowId, token.RunId) + run := env.systemSDKClient.GetWorkflow(s.Context(), token.WorkflowId, token.RunId) s.NoError(run.Get(s.Context(), nil)) return tokenString } diff --git a/tests/testcore/clients.go b/tests/testcore/clients.go index 556def736ef..179ea93f8cd 100644 --- a/tests/testcore/clients.go +++ b/tests/testcore/clients.go @@ -113,6 
+113,9 @@ func (c *clients) ensureHistory() { } func (c *clients) MatchingClient() matchingservice.MatchingServiceClient { + if c.matching.client == nil { + c.logger.Fatal("matching test client has not been initialized") + } return c.matching.client } diff --git a/tests/testcore/functional_test_base.go b/tests/testcore/functional_test_base.go index e7440b04d69..36d3f4e9486 100644 --- a/tests/testcore/functional_test_base.go +++ b/tests/testcore/functional_test_base.go @@ -24,7 +24,6 @@ import ( "go.temporal.io/server/api/adminservice/v1" persistencespb "go.temporal.io/server/api/persistence/v1" "go.temporal.io/server/common" - "go.temporal.io/server/common/archiver/provider" "go.temporal.io/server/common/config" "go.temporal.io/server/common/dynamicconfig" "go.temporal.io/server/common/log" @@ -33,8 +32,6 @@ import ( "go.temporal.io/server/common/payloads" "go.temporal.io/server/common/persistence" persistencetests "go.temporal.io/server/common/persistence/persistence-tests" - "go.temporal.io/server/common/persistence/sql/sqlplugin/sqlite" - "go.temporal.io/server/common/primitives" "go.temporal.io/server/common/primitives/timestamp" "go.temporal.io/server/common/rpc" "go.temporal.io/server/common/searchattribute" @@ -47,7 +44,7 @@ import ( "go.temporal.io/server/common/testing/testtelemetry" "go.temporal.io/server/common/testing/updateutils" "go.temporal.io/server/components/nexusoperations" - "go.uber.org/fx" + "go.temporal.io/server/temporal" ) type ( @@ -93,18 +90,16 @@ type ( } // TestClusterParams contains the variables which are used to configure test cluster via the TestClusterOption type. 
TestClusterParams struct { - ServiceOptions map[primitives.ServiceName][]fx.Option - DCRedirectionPolicy config.DCRedirectionPolicy - DynamicConfigOverrides map[dynamicconfig.Key]any - ArchivalEnabled bool - EnableMTLS bool - EnableWorkerService bool - FaultInjectionConfig *config.FaultInjection - NumHistoryShards int32 - Logger log.Logger - SharedCluster bool - CustomHistoryArchiverFactory provider.CustomHistoryArchiverFactory - CustomVisibilityArchiverFactory provider.CustomVisibilityArchiverFactory + DCRedirectionPolicy config.DCRedirectionPolicy + DynamicConfigOverrides map[dynamicconfig.Key]any + EnableMTLS bool + EnableWorkerService bool + FaultInjectionConfig *config.FaultInjection + NumHistoryShards int32 + Logger log.Logger + SharedCluster bool + ServerConfigOverride func(*config.Config) + ServerOptions []temporal.ServerOption } TestClusterOption func(params *TestClusterParams) ) @@ -116,25 +111,6 @@ func init() { sdkworker.SetBinaryChecksum("oss-server-test") } -// WithFxOptionsForService returns an Option which, when passed as an argument to setupSuite, will append the given list -// of fx options to the end of the arguments to the fx.New call for the given service. For example, if you want to -// obtain the shard controller for the history service, you can do this: -// -// var shardController shard.Controller -// s.setupSuite(t, tests.WithFxOptionsForService(primitives.HistoryService, fx.Populate(&shardController))) -// // now you can use shardController during your test -// -// This is similar to the pattern of plumbing dependencies through the TestClusterConfig, but it's much more convenient, -// scalable and flexible. The reason we need to do this on a per-service basis is that there are separate fx apps for -// each one. -// -// Deprecated: prefer dedicated TestClusterOption helpers or testhooks over injecting arbitrary Fx options. 
-func WithFxOptionsForService(serviceName primitives.ServiceName, options ...fx.Option) TestClusterOption { - return func(params *TestClusterParams) { - params.ServiceOptions[serviceName] = append(params.ServiceOptions[serviceName], options...) - } -} - func WithDCRedirectionPolicy(policy config.DCRedirectionPolicy) TestClusterOption { return func(params *TestClusterParams) { params.DCRedirectionPolicy = policy @@ -151,13 +127,19 @@ func WithDynamicConfigOverrides(overrides map[dynamicconfig.Key]any) TestCluster } } -func WithArchivalEnabled() TestClusterOption { +func WithServerConfigOverride(override func(*config.Config)) TestClusterOption { return func(params *TestClusterParams) { - params.ArchivalEnabled = true + params.ServerConfigOverride = override } } -func withMTLS() TestClusterOption { +func WithServerOptions(options ...temporal.ServerOption) TestClusterOption { + return func(params *TestClusterParams) { + params.ServerOptions = append(params.ServerOptions, options...) + } +} + +func WithMTLS() TestClusterOption { return func(params *TestClusterParams) { params.EnableMTLS = true } @@ -195,18 +177,6 @@ func WithSharedCluster() TestClusterOption { } } -func WithCustomHistoryArchiverFactory(factory provider.CustomHistoryArchiverFactory) TestClusterOption { - return func(params *TestClusterParams) { - params.CustomHistoryArchiverFactory = factory - } -} - -func WithCustomVisibilityArchiverFactory(factory provider.CustomVisibilityArchiverFactory) TestClusterOption { - return func(params *TestClusterParams) { - params.CustomVisibilityArchiverFactory = factory - } -} - func (s *FunctionalTestBase) GetTestCluster() *TestCluster { return s.testCluster } @@ -315,20 +285,19 @@ func (s *FunctionalTestBase) setupCluster(options ...TestClusterOption) { HistoryConfig: HistoryConfig{ NumHistoryShards: cmp.Or(params.NumHistoryShards, 4), }, - DCRedirectionPolicy: params.DCRedirectionPolicy, - DynamicConfigOverrides: params.DynamicConfigOverrides, - ServiceFxOptions: 
params.ServiceOptions, - EnableMetricsCapture: true, - EnableArchival: params.ArchivalEnabled, - EnableMTLS: params.EnableMTLS, - CustomHistoryArchiverFactory: params.CustomHistoryArchiverFactory, - CustomVisibilityArchiverFactory: params.CustomVisibilityArchiverFactory, - WorkerConfig: WorkerConfig{DisableWorker: !params.EnableWorkerService}, + DCRedirectionPolicy: params.DCRedirectionPolicy, + DynamicConfigOverrides: params.DynamicConfigOverrides, + EnableMetricsCapture: true, + EnableMTLS: params.EnableMTLS, + ServerConfigOverride: params.ServerConfigOverride, + ServerOptions: params.ServerOptions, + WorkerConfig: WorkerConfig{DisableWorker: !params.EnableWorkerService}, } // Apply configuration for shared clusters. if params.SharedCluster { - s.testClusterConfig.Persistence = sharedClusterPersistence(GetPersistenceTestDefaults()) + // Use file-based SQLite for shared clusters to support parallel test access. + s.testClusterConfig.Persistence = *persistencetests.GetSQLiteFileTestClusterOption() s.isShared = true } @@ -359,14 +328,6 @@ func (s *FunctionalTestBase) setupCluster(options ...TestClusterOption) { s.Require().NoError(err) } -func sharedClusterPersistence(defaults persistencetests.TestBaseOptions) persistencetests.TestBaseOptions { - if defaults.StoreType == config.StoreTypeSQL && defaults.SQLDBPluginName == sqlite.PluginName { - // Use file-based SQLite for shared clusters to support parallel test access. - return *persistencetests.GetSQLiteFileTestClusterOption() - } - return defaults -} - // All test suites that inherit FunctionalTestBase and overwrite SetupTest must // call this testcore FunctionalTestBase.SetupTest function to distribute the tests // into partitions. 
Otherwise, the test suite will be executed multiple times @@ -403,7 +364,6 @@ func (s *FunctionalTestBase) checkTestShard() { func ApplyTestClusterOptions(options []TestClusterOption) TestClusterParams { params := TestClusterParams{ - ServiceOptions: make(map[primitives.ServiceName][]fx.Option), EnableWorkerService: true, } for _, opt := range options { @@ -751,7 +711,7 @@ func (s *FunctionalTestBase) SendSignal(nsName string, execution *commonpb.Workf // RegisterTest records t as currently using this cluster. At t's Cleanup it // fails t if the cluster was poisoned during t's window. This fails all active tests // currently running. The cluster will be torn down if t was the last active test on a poisoned cluster. -// The cluster pool's slot reference is replaced as soon as poison is observed. +// The pool's slot reference is replaced as soon as poison is observed. func (s *FunctionalTestBase) RegisterTest(t testlogger.CleanupCapableT) { if s.t != nil { s.t.addTest(t) diff --git a/tests/testcore/functional_test_base_test.go b/tests/testcore/functional_test_base_test.go index 20f90c7be36..aa51d79f3ac 100644 --- a/tests/testcore/functional_test_base_test.go +++ b/tests/testcore/functional_test_base_test.go @@ -6,9 +6,7 @@ import ( "time" "github.com/stretchr/testify/suite" - "go.temporal.io/server/common/primitives" "go.temporal.io/server/service/worker" - "go.uber.org/fx" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" healthpb "google.golang.org/grpc/health/grpc_health_v1" @@ -16,11 +14,6 @@ import ( type FunctionalTestBaseSuite struct { FunctionalTestBase - - frontendServiceName primitives.ServiceName - matchingServiceName primitives.ServiceName - historyServiceName primitives.ServiceName - workerServiceName primitives.ServiceName } func TestFunctionalTestBaseSuite(t *testing.T) { @@ -29,12 +22,7 @@ func TestFunctionalTestBaseSuite(t *testing.T) { } func (s *FunctionalTestBaseSuite) SetupSuite() { - 
s.FunctionalTestBase.SetupSuiteWithCluster( - WithFxOptionsForService(primitives.FrontendService, fx.Populate(&s.frontendServiceName)), - WithFxOptionsForService(primitives.MatchingService, fx.Populate(&s.matchingServiceName)), - WithFxOptionsForService(primitives.HistoryService, fx.Populate(&s.historyServiceName)), - WithFxOptionsForService(primitives.WorkerService, fx.Populate(&s.workerServiceName)), - ) + s.SetupSuiteWithCluster() } func (s *FunctionalTestBaseSuite) TearDownSuite() { @@ -45,19 +33,6 @@ func (s *FunctionalTestBaseSuite) SetupTest() { s.FunctionalTestBase.SetupTest() } -func (s *FunctionalTestBaseSuite) TestWithFxOptionsForService() { - // This test works by using the WithFxOptionsForService option to obtain the ServiceName from the graph, and then - // it verifies that the ServiceName is correct. It does this because we are targeting the fx.App for a particular - // service, so we'll know our fx options were provided to the right service if, when we use them to get the current - // service name, it matches the target service. A more realistic example would use the option to obtain an actual - // useful object like a history shard controller, or do some graph modifications with fx.Decorate. - - s.Equal(primitives.FrontendService, s.frontendServiceName) - s.Equal(primitives.MatchingService, s.matchingServiceName) - s.Equal(primitives.HistoryService, s.historyServiceName) - s.Equal(primitives.WorkerService, s.workerServiceName) -} - func (s *FunctionalTestBaseSuite) TestWorkerServiceHealthCheck() { // This test verifies that the worker service exposes a working gRPC health check endpoint. 
conn, err := grpc.NewClient( diff --git a/tests/testcore/onebox.go b/tests/testcore/onebox.go index b94fc3159e2..3ac1b720a66 100644 --- a/tests/testcore/onebox.go +++ b/tests/testcore/onebox.go @@ -2,27 +2,27 @@ package testcore import ( "context" - "crypto/tls" "encoding/json" "fmt" "maps" "math/rand" "net" - "slices" "strconv" "sync" "testing" "time" - sdktrace "go.opentelemetry.io/otel/sdk/trace" + "go.temporal.io/api/operatorservice/v1" + "go.temporal.io/api/workflowservice/v1" "go.temporal.io/server/api/adminservice/v1" + "go.temporal.io/server/api/historyservice/v1" + "go.temporal.io/server/api/matchingservice/v1" "go.temporal.io/server/chasm" chasmnexus "go.temporal.io/server/chasm/lib/nexusoperation" + schedulerpb "go.temporal.io/server/chasm/lib/scheduler/gen/schedulerpb/v1" chasmtests "go.temporal.io/server/chasm/lib/tests" "go.temporal.io/server/client" "go.temporal.io/server/common" - carchiver "go.temporal.io/server/common/archiver" - "go.temporal.io/server/common/archiver/provider" "go.temporal.io/server/common/authorization" "go.temporal.io/server/common/cluster" "go.temporal.io/server/common/config" @@ -34,84 +34,50 @@ import ( "go.temporal.io/server/common/metrics" "go.temporal.io/server/common/metrics/metricstest" "go.temporal.io/server/common/namespace" - "go.temporal.io/server/common/namespace/nsreplication" "go.temporal.io/server/common/persistence" persistenceClient "go.temporal.io/server/common/persistence/client" "go.temporal.io/server/common/persistence/visibility" - esclient "go.temporal.io/server/common/persistence/visibility/store/elasticsearch/client" "go.temporal.io/server/common/primitives" "go.temporal.io/server/common/resolver" - "go.temporal.io/server/common/resource" - "go.temporal.io/server/common/rpc" "go.temporal.io/server/common/rpc/auth" "go.temporal.io/server/common/rpc/encryption" - "go.temporal.io/server/common/sdk" - "go.temporal.io/server/common/searchattribute" - "go.temporal.io/server/common/telemetry" 
"go.temporal.io/server/common/testing/testhooks" "go.temporal.io/server/components/nexusoperations" - "go.temporal.io/server/service/frontend" - "go.temporal.io/server/service/history" - "go.temporal.io/server/service/history/replication" - "go.temporal.io/server/service/history/tasks" - "go.temporal.io/server/service/matching" - "go.temporal.io/server/service/worker" "go.temporal.io/server/temporal" - "go.uber.org/fx" "go.uber.org/multierr" "google.golang.org/grpc" ) type ( TemporalImpl struct { - clients - fxApps []*fx.App - // This is used to wait for namespace registries to have noticed a change in some xdc tests. namespaceRegistries []namespace.Registry - // Address for SDK to connect to, using membership grpc resolver. - frontendMembershipAddress string - chasmEngine chasm.Engine - chasmVisibilityMgr chasm.VisibilityManager - - dcClient *dynamicconfig.MemoryClient - testHooks testhooks.TestHooks - logger log.Logger - clusterMetadataConfig *cluster.Config - persistenceConfig config.Persistence - metadataMgr persistence.MetadataManager - clusterMetadataMgr persistence.ClusterMetadataManager - shardMgr persistence.ShardManager - taskMgr persistence.TaskManager - executionManager persistence.ExecutionManager - namespaceReplicationQueue persistence.NamespaceReplicationQueue - abstractDataStoreFactory persistenceClient.AbstractDataStoreFactory - visibilityStoreFactory visibility.VisibilityStoreFactory - archiverMetadata carchiver.ArchivalMetadata - archiverProvider provider.ArchiverProvider - frontendConfig FrontendConfig - historyConfig HistoryConfig - matchingConfig MatchingConfig - workerConfig WorkerConfig - esConfig *esclient.Config - esClient esclient.Client - mockAdminClient map[string]adminservice.AdminServiceClient - namespaceReplicationTaskExecutor nsreplication.TaskExecutor - dcRedirectionPolicy config.DCRedirectionPolicy - tlsConfigProvider *encryption.FixedTLSConfigProvider - captureMetricsHandler *metricstest.CaptureHandler - hostsByProtocolByService 
map[transferProtocol]map[primitives.ServiceName]static.Hosts + chasmEngine chasm.Engine + chasmVisibilityMgr chasm.VisibilityManager + + clients clients + + dcClient *dynamicconfig.MemoryClient + testHooks testhooks.TestHooks + logger log.Logger + config *config.Config + abstractDataStoreFactory persistenceClient.AbstractDataStoreFactory + visibilityStoreFactory visibility.VisibilityStoreFactory + mockAdminClient map[string]adminservice.AdminServiceClient + tlsConfigProvider *encryption.FixedTLSConfigProvider + tokenProvider auth.TokenProvider + captureMetricsHandler *metricstest.CaptureHandler + hostsByProtocolByService map[transferProtocol]map[primitives.ServiceName]static.Hosts + serverOptions []temporal.ServerOption onGetClaims func(*authorization.AuthInfo) (*authorization.Claims, error) onAuthorize func(context.Context, *authorization.Claims, *authorization.CallTarget) (authorization.Result, error) callbackLock sync.RWMutex // Must be used for above callbacks - serviceFxOptions map[primitives.ServiceName][]fx.Option - taskCategoryRegistry tasks.TaskCategoryRegistry chasmRegistry *chasm.Registry replicationStreamRecorder *ReplicationStreamRecorder taskQueueRecorder *TaskQueueRecorder - spanExporters map[telemetry.SpanExporterType]sdktrace.SpanExporter - tokenProvider auth.TokenProvider + + servers []*temporal.ServerFx } // FrontendConfig is the config for the frontend service @@ -138,42 +104,20 @@ type ( // TemporalParams contains everything needed to bootstrap Temporal TemporalParams struct { - ClusterMetadataConfig *cluster.Config - PersistenceConfig config.Persistence - MetadataMgr persistence.MetadataManager - ClusterMetadataManager persistence.ClusterMetadataManager - ShardMgr persistence.ShardManager - ExecutionManager persistence.ExecutionManager - TaskMgr persistence.TaskManager - NamespaceReplicationQueue persistence.NamespaceReplicationQueue - AbstractDataStoreFactory persistenceClient.AbstractDataStoreFactory - VisibilityStoreFactory 
visibility.VisibilityStoreFactory - Logger log.Logger - ArchiverMetadata carchiver.ArchivalMetadata - ArchiverProvider provider.ArchiverProvider - EnableReadHistoryFromArchival bool - FrontendConfig FrontendConfig - HistoryConfig HistoryConfig - MatchingConfig MatchingConfig - WorkerConfig WorkerConfig - ESConfig *esclient.Config - ESClient esclient.Client - MockAdminClient map[string]adminservice.AdminServiceClient - NamespaceReplicationTaskExecutor nsreplication.TaskExecutor - DCRedirectionPolicy config.DCRedirectionPolicy - DynamicConfigOverrides map[dynamicconfig.Key]any - TLSConfigProvider *encryption.FixedTLSConfigProvider - CaptureMetricsHandler *metricstest.CaptureHandler - // ServiceFxOptions is populated by WithFxOptionsForService. - ServiceFxOptions map[primitives.ServiceName][]fx.Option - TaskCategoryRegistry tasks.TaskCategoryRegistry - HostsByProtocolByService map[transferProtocol]map[primitives.ServiceName]static.Hosts - SpanExporters map[telemetry.SpanExporterType]sdktrace.SpanExporter + Config *config.Config + AbstractDataStoreFactory persistenceClient.AbstractDataStoreFactory + VisibilityStoreFactory visibility.VisibilityStoreFactory + Logger log.Logger + MockAdminClient map[string]adminservice.AdminServiceClient + DynamicConfigOverrides map[dynamicconfig.Key]any + TLSConfigProvider *encryption.FixedTLSConfigProvider TokenProvider auth.TokenProvider + CaptureMetricsHandler *metricstest.CaptureHandler + HostsByProtocolByService map[transferProtocol]map[primitives.ServiceName]static.Hosts + ServerOptions []temporal.ServerOption } - listenHostPort string - httpPort int + httpPort int ) const NamespaceCacheRefreshInterval = time.Second @@ -181,57 +125,40 @@ const NamespaceCacheRefreshInterval = time.Second // newTemporal returns an instance that hosts full temporal in one process func newTemporal(t *testing.T, params *TemporalParams) *TemporalImpl { impl := &TemporalImpl{ - logger: params.Logger, - clusterMetadataConfig: params.ClusterMetadataConfig, 
- persistenceConfig: params.PersistenceConfig, - metadataMgr: params.MetadataMgr, - clusterMetadataMgr: params.ClusterMetadataManager, - shardMgr: params.ShardMgr, - taskMgr: params.TaskMgr, - executionManager: params.ExecutionManager, - namespaceReplicationQueue: params.NamespaceReplicationQueue, - abstractDataStoreFactory: params.AbstractDataStoreFactory, - visibilityStoreFactory: params.VisibilityStoreFactory, - esConfig: params.ESConfig, - esClient: params.ESClient, - archiverMetadata: params.ArchiverMetadata, - archiverProvider: params.ArchiverProvider, - frontendConfig: params.FrontendConfig, - historyConfig: params.HistoryConfig, - matchingConfig: params.MatchingConfig, - workerConfig: params.WorkerConfig, - mockAdminClient: params.MockAdminClient, - namespaceReplicationTaskExecutor: params.NamespaceReplicationTaskExecutor, - dcRedirectionPolicy: params.DCRedirectionPolicy, - tlsConfigProvider: params.TLSConfigProvider, - captureMetricsHandler: params.CaptureMetricsHandler, - dcClient: dynamicconfig.NewMemoryClient(), - testHooks: testhooks.NewTestHooks(), - serviceFxOptions: params.ServiceFxOptions, - taskCategoryRegistry: params.TaskCategoryRegistry, - hostsByProtocolByService: params.HostsByProtocolByService, - replicationStreamRecorder: NewReplicationStreamRecorder(), - spanExporters: params.SpanExporters, - tokenProvider: params.TokenProvider, + logger: params.Logger, + config: params.Config, + abstractDataStoreFactory: params.AbstractDataStoreFactory, + visibilityStoreFactory: params.VisibilityStoreFactory, + mockAdminClient: params.MockAdminClient, + tlsConfigProvider: params.TLSConfigProvider, + tokenProvider: params.TokenProvider, + captureMetricsHandler: params.CaptureMetricsHandler, + serverOptions: params.ServerOptions, + dcClient: dynamicconfig.NewMemoryClient(), + testHooks: testhooks.NewTestHooks(), + hostsByProtocolByService: params.HostsByProtocolByService, + replicationStreamRecorder: NewReplicationStreamRecorder(), } - - // Configure 
output file path for on-demand logging (call WriteToLog() to write) - clusterName := params.ClusterMetadataConfig.CurrentClusterName - outputFile := fmt.Sprintf("/tmp/replication_stream_messages_%s.txt", clusterName) - impl.replicationStreamRecorder.SetOutputFile(outputFile) impl.clients = newClients( impl.logger, impl.hostsByProtocolByService[grpcProtocol], impl.tlsConfigProvider, ) + // Configure output file path for on-demand logging (call WriteToLog() to write) + clusterName := params.Config.ClusterMetadata.CurrentClusterName + outputFile := fmt.Sprintf("/tmp/replication_stream_messages_%s.txt", clusterName) + impl.replicationStreamRecorder.SetOutputFile(outputFile) + // Global defaults: applied without cleanup so they persist across cluster reuse. for k, v := range defaultDynamicConfigOverrides { impl.overrideDynamicConfigForClusterLifetime(k, v) } // Override Nexus callback URL. This is parameterized on the frontend's HTTP address, // so it can't be overriden in the loop above. - impl.setNexusCallbackURL() + if len(impl.hostsByProtocolByService[httpProtocol][primitives.FrontendService].All) > 0 { + impl.setNexusCallbackURL() + } // Per-test overrides: cleaned up when the creating test finishes. 
for k, v := range params.DynamicConfigOverrides { impl.overrideDynamicConfigForTest(t, k, v) @@ -240,31 +167,28 @@ func newTemporal(t *testing.T, params *TemporalParams) *TemporalImpl { } func (c *TemporalImpl) Start() error { - // create temporal-system namespace, this must be created before starting - // the services - so directly use the metadataManager to create this - if err := c.createSystemNamespace(); err != nil { - return err + for _, serviceName := range []primitives.ServiceName{ + primitives.MatchingService, + primitives.HistoryService, + primitives.FrontendService, + primitives.WorkerService, + } { + for _, host := range c.hostsByProtocolByService[grpcProtocol][serviceName].All { + if err := c.startHost(serviceName, host); err != nil { + return multierr.Combine(err, c.Stop()) + } + } } - c.startMatching() - c.startHistory() - c.startFrontend() - c.startWorker() - return nil } func (c *TemporalImpl) Stop() error { - ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) - defer cancel() - var errs []error - errs = append(errs, c.close()...) - - slices.Reverse(c.fxApps) // less log spam if we go backwards - for _, app := range c.fxApps { - errs = append(errs, app.Stop(ctx)) + for i := len(c.servers) - 1; i >= 0; i-- { // less log spam if we go backwards + server := c.servers[i] + errs = append(errs, server.Stop()) } - + errs = append(errs, c.clients.close()...) return multierr.Combine(errs...) 
} @@ -295,6 +219,30 @@ func (c *TemporalImpl) WorkerGRPCAddress() string { return c.hostsByProtocolByService[grpcProtocol][primitives.WorkerService].All[0] } +func (c *TemporalImpl) AdminClient() adminservice.AdminServiceClient { + return c.clients.AdminClient() +} + +func (c *TemporalImpl) OperatorClient() operatorservice.OperatorServiceClient { + return c.clients.OperatorClient() +} + +func (c *TemporalImpl) FrontendClient() workflowservice.WorkflowServiceClient { + return c.clients.FrontendClient() +} + +func (c *TemporalImpl) HistoryClient() historyservice.HistoryServiceClient { + return c.clients.HistoryClient() +} + +func (c *TemporalImpl) MatchingClient() matchingservice.MatchingServiceClient { + return c.clients.MatchingClient() +} + +func (c *TemporalImpl) SchedulerClient() schedulerpb.SchedulerServiceClient { + return c.clients.SchedulerClient() +} + func (c *TemporalImpl) DcClient() *dynamicconfig.MemoryClient { return c.dcClient } @@ -314,358 +262,220 @@ func (c *TemporalImpl) ChasmVisibilityManager() chasm.VisibilityManager { return c.chasmVisibilityMgr } -func (c *TemporalImpl) copyPersistenceConfig() config.Persistence { - persistenceConfig := copyPersistenceConfig(c.persistenceConfig) - if c.esConfig != nil { - esDataStoreName := "es-visibility" - persistenceConfig.VisibilityStore = esDataStoreName - persistenceConfig.DataStores[esDataStoreName] = config.DataStore{ - Elasticsearch: c.esConfig, - } +func (c *TemporalImpl) startHost(serviceName primitives.ServiceName, host string) error { + logger := log.With(c.logger, tag.Host(host)) + opts := c.serverOptionsForHost(serviceName, host, logger) + + cleanupHooks := c.installHostTestHooks(serviceName) + defer cleanupHooks() + + server, err := temporal.NewServerFx(temporal.TopLevelModule, opts...) 
+ if err != nil { + return fmt.Errorf("unable to construct %s temporal host %s: %w", serviceName, host, err) } - return persistenceConfig -} - -func (c *TemporalImpl) startFrontend() { - serviceName := primitives.FrontendService - - var matchingRawClient resource.MatchingRawClient - var grpcResolver *membership.GRPCResolver - - for _, host := range c.hostsByProtocolByService[grpcProtocol][serviceName].All { - logger := log.With(c.logger, tag.Host(host)) - var namespaceRegistry namespace.Registry - app := fx.New( - fx.Supply( - c.copyPersistenceConfig(), - serviceName, - c.mockAdminClient, - ), - fx.Provide(c.frontendConfigProvider), - fx.Provide(func() listenHostPort { return listenHostPort(host) }), - fx.Provide(func() httpPort { return mustPortFromAddress(c.FrontendHTTPAddress()) }), - fx.Provide(func() config.DCRedirectionPolicy { return c.dcRedirectionPolicy }), - fx.Provide(func() log.Logger { return logger }), - fx.Provide(func() log.ThrottledLogger { return logger }), - fx.Provide(func() resource.NamespaceLogger { return logger }), - fx.Provide(c.newRPCFactory), - static.MembershipModule(c.makeHostMap(serviceName, host)), - fx.Provide(func() *cluster.Config { return c.clusterMetadataConfig }), - fx.Provide(func() carchiver.ArchivalMetadata { return c.archiverMetadata }), - fx.Provide(func() provider.ArchiverProvider { return c.archiverProvider }), - fx.Provide(sdkClientFactoryProvider), - fx.Provide(c.GetMetricsHandler), - fx.Provide(func() []grpc.UnaryServerInterceptor { - if c.replicationStreamRecorder != nil { - return []grpc.UnaryServerInterceptor{ - c.replicationStreamRecorder.UnaryServerInterceptor(c.clusterMetadataConfig.CurrentClusterName), - } - } - return nil - }), - fx.Provide(func() []grpc.StreamServerInterceptor { - if c.replicationStreamRecorder != nil { - return []grpc.StreamServerInterceptor{ - c.replicationStreamRecorder.StreamServerInterceptor(c.clusterMetadataConfig.CurrentClusterName), - } - } - return nil - }), - fx.Provide(func() 
authorization.Authorizer { return c }), - fx.Provide(func() authorization.ClaimMapper { return c }), - fx.Provide(func() authorization.JWTAudienceMapper { return nil }), - fx.Provide(c.newClientFactoryProvider), - fx.Provide(func() searchattribute.Mapper { return nil }), - // Comment the line above and uncomment the line below to test with search attributes mapper. - // fx.Provide(func() searchattribute.Mapper { return NewSearchAttributeTestMapper() }), - fx.Provide(func() resolver.ServiceResolver { return resolver.NewNoopResolver() }), - fx.Provide(persistenceClient.FactoryProvider), - fx.Provide(func() persistenceClient.AbstractDataStoreFactory { return c.abstractDataStoreFactory }), - fx.Provide(func() visibility.VisibilityStoreFactory { return c.visibilityStoreFactory }), - fx.Provide(func() dynamicconfig.Client { return c.dcClient }), - fx.Decorate(func() testhooks.TestHooks { return c.testHooks }), - fx.Provide(resource.DefaultSnTaggedLoggerProvider), - fx.Provide(func() esclient.Client { return c.esClient }), - fx.Provide(c.GetTLSConfigProvider), - fx.Provide(c.GetTaskCategoryRegistry), - temporal.TraceExportModule, - temporal.ServiceTracingModule, - frontend.Module, - fx.Populate(&namespaceRegistry, &grpcResolver, &matchingRawClient), - temporal.FxLogAdapter, - c.getFxOptionsForService(primitives.FrontendService), - chasm.Module, - chasmtests.Module, - ) - err := app.Err() - if err != nil { - logger.Fatal("unable to construct frontend service", tag.Error(err)) - } + if err := server.Start(); err != nil { + return fmt.Errorf("unable to start %s temporal host %s: %w", serviceName, host, err) + } + c.servers = append(c.servers, server) + return nil +} - c.fxApps = append(c.fxApps, app) - c.namespaceRegistries = append(c.namespaceRegistries, namespaceRegistry) - // TODO: create matching client without reaching into fx graph - c.matching.client = matchingRawClient +func (c *TemporalImpl) serverOptionsForHost( + serviceName primitives.ServiceName, + host string, 
+ logger log.Logger, +) []temporal.ServerOption { + options := []temporal.ServerOption{ + temporal.WithConfig(c.configForHost(serviceName, host)), + temporal.ForServices([]string{string(serviceName)}), + temporal.WithStaticHosts(c.makeHostMap(serviceName, host)), + temporal.WithLogger(logger), + temporal.WithNamespaceLogger(logger), + temporal.WithDynamicConfigClient(c.dcClient), + temporal.WithCustomDataStoreFactory(c.abstractDataStoreFactory), + temporal.WithCustomVisibilityStoreFactory(c.visibilityStoreFactory), + temporal.WithClientFactoryProvider(c.newClientFactoryProvider(c.config.ClusterMetadata, c.mockAdminClient)), + temporal.WithTestHooks(c.testHooks), + temporal.WithAuthorizer(c), + temporal.WithClaimMapper(func(*config.Config) authorization.ClaimMapper { return c }), + temporal.WithAudienceGetter(func(*config.Config) authorization.JWTAudienceMapper { return nil }), + temporal.WithSearchAttributesMapper(nil), + temporal.WithPersistenceServiceResolver(resolver.NewNoopResolver()), + temporal.WithCustomMetricsHandler(c.GetMetricsHandler()), + } + if c.tlsConfigProvider != nil { + options = append(options, temporal.WithTLSConfigFactory(c.tlsConfigProvider)) + } + if c.tokenProvider != nil { + options = append(options, temporal.WithTokenProvider(c.tokenProvider)) + } + if serviceName == primitives.FrontendService && c.replicationStreamRecorder != nil { + options = append(options, temporal.WithChainedFrontendGrpcInterceptors( + c.replicationStreamRecorder.UnaryServerInterceptor(c.config.ClusterMetadata.CurrentClusterName), + )) + } + options = append(options, c.serverOptions...) 
+ return options +} - if err := app.Start(context.Background()); err != nil { - logger.Fatal("unable to start frontend service", tag.Error(err)) - } +func (c *TemporalImpl) installHostTestHooks( + serviceName primitives.ServiceName, +) func() { + var cleanups []func() + addCleanup := func(cleanup func()) { + cleanups = append(cleanups, cleanup) } - // Address for SDKs - c.frontendMembershipAddress = grpcResolver.MakeURL(serviceName) -} - -func (c *TemporalImpl) startHistory() { - serviceName := primitives.HistoryService - - for _, host := range c.hostsByProtocolByService[grpcProtocol][serviceName].All { - var namespaceRegistry namespace.Registry - logger := log.With(c.logger, tag.Host(host)) - app := fx.New( - fx.Supply( - c.copyPersistenceConfig(), - serviceName, - c.mockAdminClient, - ), - fx.Provide(c.configProvider), - fx.Provide(c.GetMetricsHandler), - fx.Provide(func() listenHostPort { return listenHostPort(host) }), - fx.Provide(func() httpPort { return mustPortFromAddress(c.FrontendHTTPAddress()) }), - fx.Provide(func() config.DCRedirectionPolicy { return config.DCRedirectionPolicy{} }), - fx.Provide(func() log.Logger { return logger }), - fx.Provide(func() log.ThrottledLogger { return logger }), - fx.Provide(c.newRPCFactory), - fx.Decorate(func(base persistence.ExecutionManager, logger log.Logger) persistence.ExecutionManager { - // Wrap ExecutionManager with recorder to capture task writes - // This wraps the FINAL ExecutionManager after all FX processing (metrics, retries, etc.) 
- c.taskQueueRecorder = NewTaskQueueRecorder(base, logger) - return c.taskQueueRecorder - }), - fx.Decorate(func(base []grpc.UnaryServerInterceptor) []grpc.UnaryServerInterceptor { + addCleanup(testhooks.Set( + c.testHooks, + testhooks.NamespaceRegistryCreated, + func(name primitives.ServiceName, registry namespace.Registry) { + if name == serviceName { + c.namespaceRegistries = append(c.namespaceRegistries, registry) + } + }, + testhooks.GlobalScope, + )) + addCleanup(testhooks.Set( + c.testHooks, + testhooks.ChasmRegistryInitializer, + func(registry *chasm.Registry) error { + return registry.Register(chasmtests.Library) + }, + testhooks.GlobalScope, + )) + addCleanup(testhooks.Set( + c.testHooks, + testhooks.ServiceClientDialOptions, + func(options map[primitives.ServiceName][]grpc.DialOption) { + dialOptions := c.clientDialOptions() + if len(dialOptions) == 0 { + return + } + for _, serviceName := range []primitives.ServiceName{ + primitives.FrontendService, + primitives.InternalFrontendService, + primitives.HistoryService, + primitives.MatchingService, + } { + options[serviceName] = append(options[serviceName], dialOptions...) 
+ } + }, + testhooks.GlobalScope, + )) + addCleanup(testhooks.Set( + c.testHooks, + testhooks.ServiceGrpcInterceptors, + func(name primitives.ServiceName, unaryInterceptors *[]grpc.UnaryServerInterceptor, streamInterceptors *[]grpc.StreamServerInterceptor) { + switch name { + case primitives.FrontendService: if c.replicationStreamRecorder != nil { - return append(base, c.replicationStreamRecorder.UnaryServerInterceptor(c.clusterMetadataConfig.CurrentClusterName)) + *streamInterceptors = append( + *streamInterceptors, + c.replicationStreamRecorder.StreamServerInterceptor(c.config.ClusterMetadata.CurrentClusterName), + ) } - return base - }), - fx.Provide(func() []grpc.StreamServerInterceptor { + case primitives.HistoryService: if c.replicationStreamRecorder != nil { - return []grpc.StreamServerInterceptor{ - c.replicationStreamRecorder.StreamServerInterceptor(c.clusterMetadataConfig.CurrentClusterName), - } + *unaryInterceptors = append( + *unaryInterceptors, + c.replicationStreamRecorder.UnaryServerInterceptor(c.config.ClusterMetadata.CurrentClusterName), + ) + *streamInterceptors = append( + *streamInterceptors, + c.replicationStreamRecorder.StreamServerInterceptor(c.config.ClusterMetadata.CurrentClusterName), + ) } - return nil - }), - static.MembershipModule(c.makeHostMap(serviceName, host)), - fx.Provide(func() *cluster.Config { return c.clusterMetadataConfig }), - fx.Provide(func() carchiver.ArchivalMetadata { return c.archiverMetadata }), - fx.Provide(func() provider.ArchiverProvider { return c.archiverProvider }), - fx.Provide(sdkClientFactoryProvider), - fx.Provide(c.newClientFactoryProvider), - fx.Provide(func() searchattribute.Mapper { return nil }), - // Comment the line above and uncomment the line below to test with search attributes mapper. 
- // fx.Provide(func() searchattribute.Mapper { return NewSearchAttributeTestMapper() }), - fx.Provide(func() resolver.ServiceResolver { return resolver.NewNoopResolver() }), - fx.Provide(persistenceClient.FactoryProvider), - fx.Provide(func() persistenceClient.AbstractDataStoreFactory { return c.abstractDataStoreFactory }), - fx.Provide(func() visibility.VisibilityStoreFactory { return c.visibilityStoreFactory }), - fx.Provide(func() dynamicconfig.Client { return c.dcClient }), - fx.Decorate(func() testhooks.TestHooks { return c.testHooks }), - fx.Provide(resource.DefaultSnTaggedLoggerProvider), - fx.Provide(func() esclient.Client { return c.esClient }), - fx.Provide(c.GetTLSConfigProvider), - fx.Provide(c.GetTaskCategoryRegistry), - temporal.TraceExportModule, - temporal.ServiceTracingModule, - history.QueueModule, - history.Module, - replication.Module, - temporal.FxLogAdapter, - c.getFxOptionsForService(primitives.HistoryService), - chasm.Module, - chasmtests.Module, - fx.Populate(&namespaceRegistry), - fx.Populate(&c.chasmEngine), - fx.Populate(&c.chasmVisibilityMgr), - fx.Populate(&c.chasmRegistry), - ) - err := app.Err() - if err != nil { - logger.Fatal("unable to construct history service", tag.Error(err)) - } - c.fxApps = append(c.fxApps, app) - c.namespaceRegistries = append(c.namespaceRegistries, namespaceRegistry) - - if err := app.Start(context.Background()); err != nil { - logger.Fatal("unable to start history service", tag.Error(err)) - } + default: + } + }, + testhooks.GlobalScope, + )) + addCleanup(testhooks.Set( + c.testHooks, + testhooks.MatchingRawClientCreated, + func(name primitives.ServiceName, client matchingservice.MatchingServiceClient) { + if name == primitives.FrontendService && c.clients.matching.client == nil { + c.clients.matching.client = client + } + }, + testhooks.GlobalScope, + )) + + switch serviceName { + case primitives.HistoryService: + addCleanup(testhooks.Set( + c.testHooks, + testhooks.HistoryChasmComponentsCreated, + 
func(refs testhooks.HistoryChasmComponents) { + c.chasmEngine = refs.Engine + c.chasmVisibilityMgr = refs.VisibilityManager + c.chasmRegistry = refs.Registry + }, + testhooks.GlobalScope, + )) + addCleanup(testhooks.Set( + c.testHooks, + testhooks.PersistenceExecutionManagerWrapper, + func(base persistence.ExecutionManager, logger log.Logger) persistence.ExecutionManager { + // Wrap ExecutionManager with recorder to capture task writes + // This wraps the FINAL ExecutionManager after all FX processing (metrics, retries, etc.) + c.taskQueueRecorder = NewTaskQueueRecorder(base, logger) + return c.taskQueueRecorder + }, + testhooks.GlobalScope, + )) + default: } -} -func (c *TemporalImpl) startMatching() { - serviceName := primitives.MatchingService - - for _, host := range c.hostsByProtocolByService[grpcProtocol][serviceName].All { - var namespaceRegistry namespace.Registry - logger := log.With(c.logger, tag.Host(host)) - app := fx.New( - fx.Supply( - c.copyPersistenceConfig(), - serviceName, - c.mockAdminClient, - ), - fx.Provide(c.configProvider), - fx.Provide(c.GetMetricsHandler), - fx.Provide(func() listenHostPort { return listenHostPort(host) }), - fx.Provide(func() httpPort { return mustPortFromAddress(c.FrontendHTTPAddress()) }), - fx.Provide(func() log.Logger { return logger }), - fx.Provide(func() log.ThrottledLogger { return logger }), - fx.Provide(c.newRPCFactory), - static.MembershipModule(c.makeHostMap(serviceName, host)), - fx.Provide(func() *cluster.Config { return c.clusterMetadataConfig }), - fx.Provide(func() carchiver.ArchivalMetadata { return c.archiverMetadata }), - fx.Provide(func() provider.ArchiverProvider { return c.archiverProvider }), - fx.Provide(c.newClientFactoryProvider), - fx.Provide(func() searchattribute.Mapper { return nil }), - fx.Provide(func() resolver.ServiceResolver { return resolver.NewNoopResolver() }), - fx.Provide(persistenceClient.FactoryProvider), - fx.Provide(func() persistenceClient.AbstractDataStoreFactory { return 
c.abstractDataStoreFactory }), - fx.Provide(func() visibility.VisibilityStoreFactory { return c.visibilityStoreFactory }), - fx.Provide(func() dynamicconfig.Client { return c.dcClient }), - fx.Decorate(func() testhooks.TestHooks { return c.testHooks }), - fx.Provide(func() esclient.Client { return c.esClient }), - fx.Provide(c.GetTLSConfigProvider), - fx.Provide(resource.DefaultSnTaggedLoggerProvider), - fx.Provide(c.GetTaskCategoryRegistry), - temporal.TraceExportModule, - temporal.ServiceTracingModule, - matching.Module, - temporal.FxLogAdapter, - c.getFxOptionsForService(primitives.MatchingService), - chasm.Module, - chasmtests.Module, - fx.Populate(&namespaceRegistry), - ) - err := app.Err() - if err != nil { - logger.Fatal("unable to start matching service", tag.Error(err)) - } - c.fxApps = append(c.fxApps, app) - c.namespaceRegistries = append(c.namespaceRegistries, namespaceRegistry) - if err := app.Start(context.Background()); err != nil { - logger.Fatal("unable to start matching service", tag.Error(err)) + return func() { + for i := len(cleanups) - 1; i >= 0; i-- { + cleanups[i]() } } } -func (c *TemporalImpl) startWorker() { - serviceName := primitives.WorkerService - - clusterConfigCopy := cluster.Config{ - EnableGlobalNamespace: c.clusterMetadataConfig.EnableGlobalNamespace, - FailoverVersionIncrement: c.clusterMetadataConfig.FailoverVersionIncrement, - MasterClusterName: c.clusterMetadataConfig.MasterClusterName, - CurrentClusterName: c.clusterMetadataConfig.CurrentClusterName, - ClusterInformation: maps.Clone(c.clusterMetadataConfig.ClusterInformation), - } - - for _, host := range c.hostsByProtocolByService[grpcProtocol][serviceName].All { - var namespaceRegistry namespace.Registry - logger := log.With(c.logger, tag.Host(host)) - app := fx.New( - - fx.Supply( - c.copyPersistenceConfig(), - serviceName, - c.mockAdminClient, - ), - fx.Provide(c.configProvider), - fx.Provide(c.GetMetricsHandler), - fx.Provide(func() listenHostPort { return 
listenHostPort(host) }), - fx.Provide(func() httpPort { return mustPortFromAddress(c.FrontendHTTPAddress()) }), - fx.Provide(func() config.DCRedirectionPolicy { return config.DCRedirectionPolicy{} }), - fx.Provide(func() log.Logger { return logger }), - fx.Provide(func() log.ThrottledLogger { return logger }), - fx.Provide(c.newRPCFactory), - static.MembershipModule(c.makeHostMap(serviceName, host)), - fx.Provide(func() *cluster.Config { return &clusterConfigCopy }), - fx.Provide(func() carchiver.ArchivalMetadata { return c.archiverMetadata }), - fx.Provide(func() provider.ArchiverProvider { return c.archiverProvider }), - fx.Provide(sdkClientFactoryProvider), - fx.Provide(c.newClientFactoryProvider), - fx.Provide(func() searchattribute.Mapper { return nil }), - fx.Provide(func() resolver.ServiceResolver { return resolver.NewNoopResolver() }), - fx.Provide(persistenceClient.FactoryProvider), - fx.Provide(func() persistenceClient.AbstractDataStoreFactory { return c.abstractDataStoreFactory }), - fx.Provide(func() visibility.VisibilityStoreFactory { return c.visibilityStoreFactory }), - fx.Provide(func() dynamicconfig.Client { return c.dcClient }), - fx.Decorate(func() testhooks.TestHooks { return c.testHooks }), - fx.Provide(resource.DefaultSnTaggedLoggerProvider), - fx.Provide(func() esclient.Client { return c.esClient }), - fx.Provide(c.GetTLSConfigProvider), - fx.Provide(c.GetTaskCategoryRegistry), - temporal.TraceExportModule, - temporal.ServiceTracingModule, - worker.Module, - temporal.FxLogAdapter, - c.getFxOptionsForService(primitives.WorkerService), - chasm.Module, - chasmtests.Module, - fx.Populate(&namespaceRegistry), +func (c *TemporalImpl) clientDialOptions() []grpc.DialOption { + var options []grpc.DialOption + if c.replicationStreamRecorder != nil { + options = append(options, + grpc.WithChainUnaryInterceptor(c.replicationStreamRecorder.UnaryInterceptor(c.config.ClusterMetadata.CurrentClusterName)), + 
grpc.WithChainStreamInterceptor(c.replicationStreamRecorder.StreamInterceptor(c.config.ClusterMetadata.CurrentClusterName)), ) - err := app.Err() - if err != nil { - logger.Fatal("unable to start worker service", tag.Error(err)) - } - - c.fxApps = append(c.fxApps, app) - c.namespaceRegistries = append(c.namespaceRegistries, namespaceRegistry) - if err := app.Start(context.Background()); err != nil { - logger.Fatal("unable to start worker service", tag.Error(err)) - } } + return options } -func (c *TemporalImpl) getFxOptionsForService(serviceName primitives.ServiceName) fx.Option { - return fx.Options(c.serviceFxOptions[serviceName]...) -} - -func (c *TemporalImpl) createSystemNamespace() error { - err := c.metadataMgr.InitializeSystemNamespaces(context.Background(), c.clusterMetadataConfig.CurrentClusterName) - if err != nil { - return fmt.Errorf("failed to create temporal-system namespace: %v", err) +func (c *TemporalImpl) configForHost(serviceName primitives.ServiceName, host string) *config.Config { + bindIP, port := mustSplitHostPort(host) + rpcConfig := config.RPC{ + BindOnIP: bindIP, + GRPCPort: int(port), + } + if serviceName == primitives.FrontendService { + // Set HTTP port and a test HTTP forwarded header + _, httpPort := mustSplitHostPort(c.FrontendHTTPAddress()) + rpcConfig.HTTPPort = int(httpPort) + rpcConfig.HTTPAdditionalForwardedHeaders = []string{ + "this-header-forwarded", + "this-header-prefix-forwarded-*", + } } - return nil -} -func (c *TemporalImpl) GetExecutionManager() persistence.ExecutionManager { - if c.taskQueueRecorder != nil { - return c.taskQueueRecorder + cfg := *c.config + cfg.Persistence = copyPersistenceConfig(c.config.Persistence) + cfg.Services = map[string]config.Service{ + string(serviceName): { + RPC: rpcConfig, + }, } - return c.executionManager + return &cfg } func (c *TemporalImpl) GetTaskQueueRecorder() *TaskQueueRecorder { return c.taskQueueRecorder } -func (c *TemporalImpl) SetTaskQueueRecorder(recorder 
*TaskQueueRecorder) { - c.taskQueueRecorder = recorder -} - -func (c *TemporalImpl) GetTLSConfigProvider() encryption.TLSConfigProvider { - // If we just return this directly, the interface will be non-nil but the - // pointer will be nil - if c.tlsConfigProvider != nil { - return c.tlsConfigProvider - } - return nil -} - -func (c *TemporalImpl) GetTaskCategoryRegistry() tasks.TaskCategoryRegistry { - return c.taskCategoryRegistry -} - func (c *TemporalImpl) GetCHASMRegistry() *chasm.Registry { return c.chasmRegistry } @@ -687,102 +497,6 @@ func (c *TemporalImpl) GetMetricsHandler() metrics.Handler { return metrics.NoopMetricsHandler } -func (c *TemporalImpl) frontendConfigProvider() *config.Config { - // Set HTTP port and a test HTTP forwarded header - return &config.Config{ - Services: map[string]config.Service{ - string(primitives.FrontendService): { - RPC: config.RPC{ - HTTPPort: int(mustPortFromAddress(c.FrontendHTTPAddress())), - HTTPAdditionalForwardedHeaders: []string{ - "this-header-forwarded", - "this-header-prefix-forwarded-*", - }, - }, - }, - }, - DCRedirectionPolicy: c.dcRedirectionPolicy, - ExporterConfig: telemetry.ExportConfig{ - CustomExporters: c.spanExporters, - }, - } -} - -func (c *TemporalImpl) configProvider(serviceName primitives.ServiceName) *config.Config { - return &config.Config{ - Services: map[string]config.Service{ - string(serviceName): { - RPC: config.RPC{}, - }, - }, - DCRedirectionPolicy: config.DCRedirectionPolicy{}, - ExporterConfig: telemetry.ExportConfig{ - CustomExporters: c.spanExporters, - }, - } -} - -func (c *TemporalImpl) newRPCFactory( - sn primitives.ServiceName, - grpcHostPort listenHostPort, - logger log.Logger, - grpcResolver *membership.GRPCResolver, - tlsConfigProvider encryption.TLSConfigProvider, - monitor membership.Monitor, - tracingStatsHandler telemetry.ClientStatsHandler, - httpPort httpPort, - metricsHandler metrics.Handler, -) (common.RPCFactory, error) { - host, portStr, err := 
net.SplitHostPort(string(grpcHostPort)) - if err != nil { - return nil, fmt.Errorf("failed parsing host:port: %w", err) - } - port, err := strconv.Atoi(portStr) - if err != nil { - return nil, fmt.Errorf("invalid port: %w", err) - } - var frontendTLSConfig *tls.Config - if tlsConfigProvider != nil { - if frontendTLSConfig, err = tlsConfigProvider.GetFrontendClientConfig(); err != nil { - return nil, fmt.Errorf("failed getting client TLS config: %w", err) - } - } - var options []grpc.DialOption - if tracingStatsHandler != nil { - options = append(options, grpc.WithStatsHandler(tracingStatsHandler)) - } - // Add replication stream recorder injector - if c.replicationStreamRecorder != nil { - options = append(options, - grpc.WithChainUnaryInterceptor(c.replicationStreamRecorder.UnaryInterceptor(c.clusterMetadataConfig.CurrentClusterName)), - grpc.WithChainStreamInterceptor(c.replicationStreamRecorder.StreamInterceptor(c.clusterMetadataConfig.CurrentClusterName)), - ) - } - rpcConfig := config.RPC{BindOnIP: host, GRPCPort: port, HTTPPort: int(httpPort)} - cfg := &config.Config{ - Services: map[string]config.Service{ - string(sn): { - RPC: rpcConfig, - }, - }, - } - return rpc.NewFactory( - cfg, - sn, - logger, - metricsHandler, - tlsConfigProvider, - grpcResolver.MakeURL(primitives.FrontendService), - grpcResolver.MakeURL(primitives.FrontendService), - int(httpPort), - frontendTLSConfig, - options, - resource.PerServiceDialOptionsProvider(logger), - monitor, - c.tokenProvider, - ), nil -} - func (c *TemporalImpl) newClientFactoryProvider( config *cluster.Config, mockAdminClient map[string]adminservice.AdminServiceClient, @@ -912,29 +626,6 @@ func copyPersistenceConfig(cfg config.Persistence) config.Persistence { return newCfg } -func sdkClientFactoryProvider( - grpcResolver *membership.GRPCResolver, - metricsHandler metrics.Handler, - logger log.Logger, - dc *dynamicconfig.Collection, - tlsConfigProvider encryption.TLSConfigProvider, -) sdk.ClientFactory { - var 
tlsConfig *tls.Config - if tlsConfigProvider != nil { - var err error - if tlsConfig, err = tlsConfigProvider.GetFrontendClientConfig(); err != nil { - panic(err) - } - } - return sdk.NewClientFactory( - grpcResolver.MakeURL(primitives.FrontendService), - tlsConfig, - metricsHandler, - logger, - dynamicconfig.WorkerStickyCacheSize.Get(dc), - ) -} - func (c *TemporalImpl) setNexusCallbackURL() { // Set Nexus callback URL with the cluster's HTTP address. This is a sensible default to avoid // users to need to manually set this. @@ -964,8 +655,8 @@ func (c *TemporalImpl) injectHook(t *testing.T, hook testhooks.Hook, scope any) return cleanup } -func mustPortFromAddress(addr string) httpPort { - _, port, err := net.SplitHostPort(addr) +func mustSplitHostPort(addr string) (string, httpPort) { + host, port, err := net.SplitHostPort(addr) if err != nil { panic(fmt.Errorf("Invalid address: %w", err)) } @@ -973,5 +664,5 @@ func mustPortFromAddress(addr string) httpPort { if err != nil { panic(fmt.Errorf("Cannot parse port: %w", err)) } - return httpPort(portInt) + return host, httpPort(portInt) } diff --git a/tests/testcore/test_cluster.go b/tests/testcore/test_cluster.go index 49c0261ed7e..506a6d40429 100644 --- a/tests/testcore/test_cluster.go +++ b/tests/testcore/test_cluster.go @@ -21,9 +21,6 @@ import ( "go.temporal.io/server/api/matchingservice/v1" persistencespb "go.temporal.io/server/api/persistence/v1" schedulerpb "go.temporal.io/server/chasm/lib/scheduler/gen/schedulerpb/v1" - "go.temporal.io/server/common/archiver" - "go.temporal.io/server/common/archiver/filestore" - "go.temporal.io/server/common/archiver/provider" "go.temporal.io/server/common/backoff" "go.temporal.io/server/common/cluster" "go.temporal.io/server/common/config" @@ -33,7 +30,6 @@ import ( "go.temporal.io/server/common/membership/static" "go.temporal.io/server/common/metrics" "go.temporal.io/server/common/metrics/metricstest" - "go.temporal.io/server/common/namespace/nsreplication" 
"go.temporal.io/server/common/persistence" persistenceclient "go.temporal.io/server/common/persistence/client" persistencetests "go.temporal.io/server/common/persistence/persistence-tests" @@ -50,7 +46,6 @@ import ( "go.temporal.io/server/temporal" "go.temporal.io/server/temporal/environment" "go.temporal.io/server/tests/testutils" - "go.uber.org/fx" "go.uber.org/multierr" ) @@ -59,45 +54,31 @@ type ( // TestCluster is a testcore struct for functional tests TestCluster struct { - testBase *persistencetests.TestBase - archiverBase *ArchiverBase - host *TemporalImpl - } - - // ArchiverBase is a testcore struct for archiver provider being used in functional tests - ArchiverBase struct { - metadata archiver.ArchivalMetadata - provider provider.ArchiverProvider - historyStoreDirectory string - visibilityStoreDirectory string - historyURI string - visibilityURI string + testBase *persistencetests.TestBase + host *TemporalImpl } // TestClusterConfig are config for a test cluster TestClusterConfig struct { - EnableArchival bool - IsMasterCluster bool - ClusterMetadata cluster.Config - Persistence persistencetests.TestBaseOptions - FrontendConfig FrontendConfig - HistoryConfig HistoryConfig - MatchingConfig MatchingConfig - WorkerConfig WorkerConfig - ESConfig *esclient.Config - MockAdminClient map[string]adminservice.AdminServiceClient - FaultInjection *config.FaultInjection - DCRedirectionPolicy config.DCRedirectionPolicy - DynamicConfigOverrides map[dynamicconfig.Key]any - EnableMTLS bool - EnableMetricsCapture bool - SpanExporters map[telemetry.SpanExporterType]sdktrace.SpanExporter - CustomHistoryArchiverFactory provider.CustomHistoryArchiverFactory - CustomVisibilityArchiverFactory provider.CustomVisibilityArchiverFactory - // ServiceFxOptions can be populated using WithFxOptionsForService. 
- ServiceFxOptions map[primitives.ServiceName][]fx.Option - TokenProvider auth.TokenProvider - TLSConfigProvider *encryption.FixedTLSConfigProvider + IsMasterCluster bool + ClusterMetadata cluster.Config + Persistence persistencetests.TestBaseOptions + FrontendConfig FrontendConfig + HistoryConfig HistoryConfig + MatchingConfig MatchingConfig + WorkerConfig WorkerConfig + ESConfig *esclient.Config + MockAdminClient map[string]adminservice.AdminServiceClient + FaultInjection *config.FaultInjection + DCRedirectionPolicy config.DCRedirectionPolicy + DynamicConfigOverrides map[dynamicconfig.Key]any + EnableMTLS bool + EnableMetricsCapture bool + SpanExporters map[telemetry.SpanExporterType]sdktrace.SpanExporter + TokenProvider auth.TokenProvider + TLSConfigProvider *encryption.FixedTLSConfigProvider + ServerConfigOverride func(*config.Config) + ServerOptions []temporal.ServerOption } TestClusterFactory interface { @@ -114,22 +95,6 @@ const ( grpcProtocol transferProtocol = "grpc" ) -func (a *ArchiverBase) Metadata() archiver.ArchivalMetadata { - return a.metadata -} - -func (a *ArchiverBase) Provider() provider.ArchiverProvider { - return a.provider -} - -func (a *ArchiverBase) HistoryURI() string { - return a.historyURI -} - -func (a *ArchiverBase) VisibilityURI() string { - return a.visibilityURI -} - func (f *defaultTestClusterFactory) NewCluster(t *testing.T, clusterConfig *TestClusterConfig, logger log.Logger) (*TestCluster, error) { return newClusterWithPersistenceTestBaseFactory(t, clusterConfig, logger, f.tbFactory) } @@ -227,21 +192,11 @@ func newClusterWithPersistenceTestBaseFactory( testBase := tbFactory.NewTestBase(&clusterConfig.Persistence) testBase.Setup(clusterMetadataConfig) - archiverBase := newArchiverBase( - clusterConfig.EnableArchival, - clusterConfig.CustomHistoryArchiverFactory, - clusterConfig.CustomVisibilityArchiverFactory, - testBase.ExecutionManager, - logger, - ) var err error pConfig := testBase.DefaultTestCluster.Config() 
pConfig.NumHistoryShards = clusterConfig.HistoryConfig.NumHistoryShards - var ( - esClient esclient.Client - ) if !UseSQLVisibility() { clusterConfig.ESConfig = &esclient.Config{ Indices: map[string]string{ @@ -264,10 +219,6 @@ func newClusterWithPersistenceTestBaseFactory( pConfig.DataStores[pConfig.VisibilityStore] = config.DataStore{ Elasticsearch: clusterConfig.ESConfig, } - esClient, err = esclient.NewClient(clusterConfig.ESConfig, nil, logger) - if err != nil { - return nil, err - } } else { clusterConfig.ESConfig = nil } @@ -325,40 +276,34 @@ func newClusterWithPersistenceTestBaseFactory( } } - temporalParams := &TemporalParams{ - ClusterMetadataConfig: clusterMetadataConfig, - PersistenceConfig: pConfig, - MetadataMgr: testBase.MetadataManager, - ClusterMetadataManager: testBase.ClusterMetadataManager, - ShardMgr: testBase.ShardMgr, - ExecutionManager: testBase.ExecutionManager, - NamespaceReplicationQueue: testBase.NamespaceReplicationQueue, - AbstractDataStoreFactory: testBase.AbstractDataStoreFactory, - VisibilityStoreFactory: testBase.VisibilityStoreFactory, - TaskMgr: testBase.TaskMgr, - Logger: logger, - ESConfig: clusterConfig.ESConfig, - ESClient: esClient, - ArchiverMetadata: archiverBase.metadata, - ArchiverProvider: archiverBase.provider, - FrontendConfig: clusterConfig.FrontendConfig, - HistoryConfig: clusterConfig.HistoryConfig, - MatchingConfig: clusterConfig.MatchingConfig, - WorkerConfig: clusterConfig.WorkerConfig, - MockAdminClient: clusterConfig.MockAdminClient, - NamespaceReplicationTaskExecutor: nsreplication.NewTaskExecutor(clusterConfig.ClusterMetadata.CurrentClusterName, testBase.MetadataManager, nsreplication.NewNoopDataMerger(), nsreplication.NewDefaultAdmitter(), logger, testhooks.TestHooks{}), - DCRedirectionPolicy: clusterConfig.DCRedirectionPolicy, - DynamicConfigOverrides: clusterConfig.DynamicConfigOverrides, - TLSConfigProvider: tlsConfigProvider, - ServiceFxOptions: clusterConfig.ServiceFxOptions, - TaskCategoryRegistry: 
temporal.TaskCategoryRegistryProvider(archiverBase.metadata), - HostsByProtocolByService: hostsByProtocolByService, - SpanExporters: clusterConfig.SpanExporters, - TokenProvider: clusterConfig.TokenProvider, + persistenceConfig := copyPersistenceConfig(pConfig) + if clusterConfig.ESConfig != nil { + esDataStoreName := "es-visibility" + persistenceConfig.VisibilityStore = esDataStoreName + persistenceConfig.DataStores[esDataStoreName] = config.DataStore{ + Elasticsearch: clusterConfig.ESConfig, + } } - + serverConfig := &config.Config{ + Global: config.Global{ + Membership: config.Membership{ + MaxJoinDuration: time.Second, + }, + }, + Persistence: persistenceConfig, + ClusterMetadata: clusterMetadataConfig, + DCRedirectionPolicy: clusterConfig.DCRedirectionPolicy, + Visibility: config.Visibility{}, + ExporterConfig: telemetry.ExportConfig{ + CustomExporters: clusterConfig.SpanExporters, + }, + } + if clusterConfig.ServerConfigOverride != nil { + clusterConfig.ServerConfigOverride(serverConfig) + } + var captureMetricsHandler *metricstest.CaptureHandler if clusterConfig.EnableMetricsCapture { - temporalParams.CaptureMetricsHandler = metricstest.NewCaptureHandler() + captureMetricsHandler = metricstest.NewCaptureHandler() } err = newPProfInitializerImpl(logger, PprofTestPort).Start() @@ -366,12 +311,24 @@ func newClusterWithPersistenceTestBaseFactory( logger.Fatal("Failed to start pprof", tag.Error(err)) } - cluster := newTemporal(t, temporalParams) - if err = cluster.Start(); err != nil { + testCluster := newTemporal(t, &TemporalParams{ + Config: serverConfig, + AbstractDataStoreFactory: testBase.AbstractDataStoreFactory, + VisibilityStoreFactory: testBase.VisibilityStoreFactory, + Logger: logger, + MockAdminClient: clusterConfig.MockAdminClient, + DynamicConfigOverrides: clusterConfig.DynamicConfigOverrides, + TLSConfigProvider: tlsConfigProvider, + HostsByProtocolByService: hostsByProtocolByService, + TokenProvider: clusterConfig.TokenProvider, + 
CaptureMetricsHandler: captureMetricsHandler, + ServerOptions: clusterConfig.ServerOptions, + }) + if err = testCluster.Start(); err != nil { return nil, err } - return &TestCluster{testBase: testBase, archiverBase: archiverBase, host: cluster}, nil + return &TestCluster{testBase: testBase, host: testCluster}, nil } func setupIndex(esConfig *esclient.Config, logger log.Logger) error { @@ -486,80 +443,17 @@ func newPProfInitializerImpl(logger log.Logger, port int) *pprof.PProfInitialize } } -func newArchiverBase( - enabled bool, - customHistoryArchiverFactory provider.CustomHistoryArchiverFactory, - customVisibilityArchiverFactory provider.CustomVisibilityArchiverFactory, - executionManager persistence.ExecutionManager, - logger log.Logger, -) *ArchiverBase { - dcCollection := dynamicconfig.NewNoopCollection() - if !enabled { - return &ArchiverBase{ - metadata: archiver.NewArchivalMetadata(dcCollection, "", false, "", false, &config.ArchivalNamespaceDefaults{}), - provider: provider.NewArchiverProvider(nil, nil, nil, nil, nil, logger, metrics.NoopMetricsHandler), - } - } - - historyStoreDirectory, err := os.MkdirTemp("", "test-history-archival") - if err != nil { - logger.Fatal("Failed to create temp dir for history archival", tag.Error(err)) - } - visibilityStoreDirectory, err := os.MkdirTemp("", "test-visibility-archival") - if err != nil { - logger.Fatal("Failed to create temp dir for visibility archival", tag.Error(err)) - } - cfg := &config.FilestoreArchiver{ - FileMode: "0666", - DirMode: "0766", - } - provider := provider.NewArchiverProvider( - &config.HistoryArchiverProvider{ - Filestore: cfg, - }, - &config.VisibilityArchiverProvider{ - Filestore: cfg, - }, - customHistoryArchiverFactory, - customVisibilityArchiverFactory, - executionManager, - logger, - metrics.NoopMetricsHandler, - ) - return &ArchiverBase{ - metadata: archiver.NewArchivalMetadata(dcCollection, "enabled", true, "enabled", true, &config.ArchivalNamespaceDefaults{ - History: 
config.HistoryArchivalNamespaceDefaults{ - State: "enabled", - URI: "testScheme://test/history/archive/path", - }, - Visibility: config.VisibilityArchivalNamespaceDefaults{ - State: "enabled", - URI: "testScheme://test/visibility/archive/path", - }, - }), - provider: provider, - historyStoreDirectory: historyStoreDirectory, - visibilityStoreDirectory: visibilityStoreDirectory, - historyURI: filestore.URIScheme + "://" + historyStoreDirectory, - visibilityURI: filestore.URIScheme + "://" + visibilityStoreDirectory, - } -} - // TearDownCluster tears down the test cluster func (tc *TestCluster) TearDownCluster() error { errs := tc.host.Stop() tc.testBase.TearDownWorkflowStore() - if !UseSQLVisibility() && tc.host.esConfig != nil { - if err := deleteIndex(tc.host.esConfig, tc.host.logger); err != nil { - errs = multierr.Combine(errs, err) + if !UseSQLVisibility() { + if esConfig := tc.host.config.Persistence.DataStores[tc.host.config.Persistence.VisibilityStore].Elasticsearch; esConfig != nil { + if err := deleteIndex(esConfig, tc.host.logger); err != nil { + errs = multierr.Combine(errs, err) + } } } - if err := os.RemoveAll(tc.archiverBase.historyStoreDirectory); err != nil { - errs = multierr.Combine(errs, err) - } - if err := os.RemoveAll(tc.archiverBase.visibilityStoreDirectory); err != nil { - errs = multierr.Combine(errs, err) - } return errs } @@ -568,10 +462,6 @@ func (tc *TestCluster) TestBase() *persistencetests.TestBase { return tc.testBase } -func (tc *TestCluster) ArchiverBase() *ArchiverBase { - return tc.archiverBase -} - func (tc *TestCluster) FrontendClient() workflowservice.WorkflowServiceClient { return tc.host.FrontendClient() } @@ -601,7 +491,10 @@ func (tc *TestCluster) SchedulerClient() schedulerpb.SchedulerServiceClient { // ExecutionManager returns an execution manager factory from the test cluster func (tc *TestCluster) ExecutionManager() persistence.ExecutionManager { - return tc.host.GetExecutionManager() + if tc.host.taskQueueRecorder != 
nil { + return tc.host.taskQueueRecorder + } + return tc.testBase.ExecutionManager } // TODO (alex): expose only needed objects from TemporalImpl. @@ -618,7 +511,7 @@ func (tc *TestCluster) WorkerGRPCAddress() string { } func (tc *TestCluster) ClusterName() string { - return tc.host.clusterMetadataConfig.CurrentClusterName + return tc.host.config.ClusterMetadata.CurrentClusterName } func (tc *TestCluster) GetReplicationStreamRecorder() *ReplicationStreamRecorder { diff --git a/tests/testcore/test_cluster_pool_test.go b/tests/testcore/test_cluster_pool_test.go index 2243ed81385..dcd848739f0 100644 --- a/tests/testcore/test_cluster_pool_test.go +++ b/tests/testcore/test_cluster_pool_test.go @@ -5,6 +5,7 @@ import ( "github.com/stretchr/testify/require" "go.temporal.io/server/common/cluster" + "go.temporal.io/server/common/config" "go.temporal.io/server/common/dynamicconfig" ) @@ -13,7 +14,9 @@ func TestGlobalOverridesSurviveTestCleanup(t *testing.T) { t.Run("create", func(t *testing.T) { impl := newTemporal(t, &TemporalParams{ - ClusterMetadataConfig: &cluster.Config{}, + Config: &config.Config{ + ClusterMetadata: &cluster.Config{}, + }, }) dcClient = impl.dcClient }) diff --git a/tests/testcore/test_env.go b/tests/testcore/test_env.go index 0aaa7255fee..a14c9ee0e6c 100644 --- a/tests/testcore/test_env.go +++ b/tests/testcore/test_env.go @@ -26,12 +26,10 @@ import ( "go.temporal.io/server/common/dynamicconfig" "go.temporal.io/server/common/log" "go.temporal.io/server/common/namespace" - "go.temporal.io/server/common/primitives" "go.temporal.io/server/common/testing/taskpoller" "go.temporal.io/server/common/testing/testhooks" "go.temporal.io/server/common/testing/testlogger" "go.temporal.io/server/common/testing/testvars" - "go.uber.org/fx" ) // shardSalt is used to distribute functional tests across shards. @@ -46,9 +44,7 @@ var ( ) type Env interface { - // T returns the *testing.T. - // - // Deprecated: use the suite's T() method instead. 
+ // T returns the *testing.T. Deprecated: use the suite's T() method instead. T() *testing.T Namespace() namespace.Name NamespaceID() namespace.ID @@ -132,24 +128,6 @@ func WithSdkWorker() TestOption { } } -// WithTestVars customizes the default test variables for the environment. -func WithTestVars(fn func(*testvars.TestVars) *testvars.TestVars) TestOption { - return func(o *testOptions) { - o.testVars = fn - } -} - -// WithFxOptions appends fx options to a specific service's fx graph. This -// implies a dedicated cluster because custom fx options cannot be shared -// across tests. -func WithFxOptions(serviceName primitives.ServiceName, opts ...fx.Option) TestOption { - return func(o *testOptions) { - o.dedicatedCluster = true - o.clusterOptions = append(o.clusterOptions, WithFxOptionsForService(serviceName, opts...)) - o.dedicatedReason = "custom fx options used" - } -} - // WithWorkerService enables the system worker service. The service is off by // default to avoid the worker overhead. This implies a dedicated cluster. func WithWorkerService(reason string) TestOption { @@ -160,13 +138,10 @@ func WithWorkerService(reason string) TestOption { } } -// WithMTLS enables mutual TLS on the test's cluster. This implies a dedicated -// cluster, since the TLS configuration cannot be shared across tests. -func WithMTLS() TestOption { +func WithClusterOptions(options ...TestClusterOption) TestOption { return func(o *testOptions) { o.dedicatedCluster = true - o.clusterOptions = append(o.clusterOptions, withMTLS()) - o.dedicatedReason = "mTLS enabled" + o.clusterOptions = append(o.clusterOptions, options...) } } @@ -215,6 +190,12 @@ func WithDynamicConfig(setting dynamicconfig.GenericSetting, value any) TestOpti } } +func WithTestVars(fn func(*testvars.TestVars) *testvars.TestVars) TestOption { + return func(o *testOptions) { + o.testVars = fn + } +} + // NewEnv creates a new test environment with access to a Temporal cluster. 
func NewEnv(t *testing.T, opts ...TestOption) *TestEnv { t.Helper() @@ -261,11 +242,6 @@ func NewEnv(t *testing.T, opts ...TestOption) *TestEnv { t.Fatalf("Failed to register namespace: %v", err) } - tv := testvars.New(t) - if options.testVars != nil { - tv = options.testVars(tv) - } - env := &TestEnv{ FunctionalTestBase: base, Assertions: require.New(t), @@ -275,7 +251,7 @@ func NewEnv(t *testing.T, opts ...TestOption) *TestEnv { Logger: base.Logger, taskPoller: taskpoller.New(t, cluster.FrontendClient(), ns.String()), t: t, - tv: tv, + tv: testvars.New(t), ctx: setupTestTimeoutWithContext(t), sdkWorkerTQ: RandomizeStr("tq-" + t.Name()), dedicatedGuard: dedicatedGuard, @@ -301,6 +277,9 @@ func NewEnv(t *testing.T, opts ...TestOption) *TestEnv { env.OverrideDynamicConfig(override.setting, override.value) } } + if options.testVars != nil { + env.tv = options.testVars(env.tv) + } return env } @@ -367,7 +346,6 @@ func (e *TestEnv) TaskPoller() *taskpoller.TaskPoller { } // NoError asserts that err is nil. -// // Deprecated: use require.NoError with the parent test or suite instead. // TODO: remove once all tests are migrated to TestEnv (and no longer use FunctionalTestBase directly). func (e *TestEnv) NoError(err error, msgAndArgs ...any) { @@ -375,7 +353,6 @@ func (e *TestEnv) NoError(err error, msgAndArgs ...any) { } // Error asserts that err is not nil. -// // Deprecated: use require.Error with the parent test or suite instead. // TODO: remove once all tests are migrated to TestEnv (and no longer use FunctionalTestBase directly). func (e *TestEnv) Error(err error, msgAndArgs ...any) { @@ -383,16 +360,13 @@ func (e *TestEnv) Error(err error, msgAndArgs ...any) { } // Run executes a subtest. -// // Deprecated: use the suite's Run method instead. // TODO: remove once all tests are migrated to TestEnv (and no longer use FunctionalTestBase directly). 
func (e *TestEnv) Run(name string, subtest func()) bool { return e.FunctionalTestBase.Run(name, subtest) } -// T returns the *testing.T. -// -// Deprecated: use the suite's T() method instead. +// T returns the *testing.T. Deprecated: use the suite's T() method instead. func (e *TestEnv) T() *testing.T { return e.t } diff --git a/tests/tls_test.go b/tests/tls_test.go index 01bcb2993d9..54a415f48d1 100644 --- a/tests/tls_test.go +++ b/tests/tls_test.go @@ -22,7 +22,7 @@ func TestTLSFunctionalSuite(t *testing.T) { func (s *TLSFunctionalSuite) newTestEnv(opts ...testcore.TestOption) *testcore.TestEnv { baseOpts := []testcore.TestOption{ - testcore.WithMTLS(), + testcore.WithClusterOptions(testcore.WithMTLS()), } return testcore.NewEnv(s.T(), append(baseOpts, opts...)...) } diff --git a/tests/xdc/base.go b/tests/xdc/base.go index ddcc1777fd3..835306f6182 100644 --- a/tests/xdc/base.go +++ b/tests/xdc/base.go @@ -144,7 +144,6 @@ func (s *xdcBaseSuite) setupSuite(opts ...testcore.TestClusterOption) { // RPCAddress and HTTPAddress will be filled in }, } - clusterConfigs[clusterIndex].ServiceFxOptions = params.ServiceOptions clusterConfigs[clusterIndex].EnableMetricsCapture = true var err error From 763bbc83d3ed761fa239fb66cef1731f5bcc7e4e Mon Sep 17 00:00:00 2001 From: Stephan Behnke Date: Sat, 20 Jun 2026 20:08:50 -0700 Subject: [PATCH 02/16] Remove onebox namespace registry hook --- common/resource/fx.go | 4 - common/testing/testhooks/hooks.go | 1 - tests/testcore/clients.go | 4 + tests/testcore/namespace.go | 76 +++++++++++++++++ tests/testcore/onebox.go | 21 +---- tests/xdc/base.go | 134 ++++++++++++++++++------------ 6 files changed, 162 insertions(+), 78 deletions(-) create mode 100644 tests/testcore/namespace.go diff --git a/common/resource/fx.go b/common/resource/fx.go index d7a0a929bab..7a64bfd78a5 100644 --- a/common/resource/fx.go +++ b/common/resource/fx.go @@ -233,7 +233,6 @@ type NamespaceRegistryParams struct { DynamicCollection *dynamicconfig.Collection 
ReplicationResolverFactory namespace.ReplicationResolverFactory NamespaceStateChangedFn namespace.NamespaceStateChangedFn - TestHooks testhooks.TestHooks `optional:"true"` } func NamespaceRegistryProvider(params NamespaceRegistryParams) namespace.Registry { @@ -248,9 +247,6 @@ func NamespaceRegistryProvider(params NamespaceRegistryParams) namespace.Registr params.ReplicationResolverFactory, params.NamespaceStateChangedFn, ) - if hook, ok := testhooks.Get(params.TestHooks, testhooks.NamespaceRegistryCreated, testhooks.GlobalScope); ok { - hook(params.ServiceName, registry) - } return registry } diff --git a/common/testing/testhooks/hooks.go b/common/testing/testhooks/hooks.go index 658ad139077..35c6e49336f 100644 --- a/common/testing/testhooks/hooks.go +++ b/common/testing/testhooks/hooks.go @@ -44,7 +44,6 @@ var ( NamespaceReplicationTaskInterceptor = newKey[func(context.Context, *replicationspb.NamespaceTaskAttributes, func() error) error, namespace.Name]() ServiceGrpcInterceptors = newKey[func(primitives.ServiceName, *[]grpc.UnaryServerInterceptor, *[]grpc.StreamServerInterceptor), global]() ServiceClientDialOptions = newKey[func(map[primitives.ServiceName][]grpc.DialOption), global]() - NamespaceRegistryCreated = newKey[func(primitives.ServiceName, namespace.Registry), global]() MatchingRawClientCreated = newKey[func(primitives.ServiceName, matchingservice.MatchingServiceClient), global]() ChasmRegistryInitializer = newKey[func(*chasm.Registry) error, global]() HistoryChasmComponentsCreated = newKey[func(HistoryChasmComponents), global]() diff --git a/tests/testcore/clients.go b/tests/testcore/clients.go index 179ea93f8cd..923b20e0526 100644 --- a/tests/testcore/clients.go +++ b/tests/testcore/clients.go @@ -139,6 +139,10 @@ func (c *clients) newConn(serviceName primitives.ServiceName) (*grpc.ClientConn, if err != nil { return nil, err } + return c.newConnToAddress(serviceName, address) +} + +func (c *clients) newConnToAddress(serviceName primitives.ServiceName, 
address string) (*grpc.ClientConn, error) { tlsConfig, err := c.tlsConfig(serviceName) if err != nil { return nil, err diff --git a/tests/testcore/namespace.go b/tests/testcore/namespace.go new file mode 100644 index 00000000000..fa67c97a3d7 --- /dev/null +++ b/tests/testcore/namespace.go @@ -0,0 +1,76 @@ +package testcore + +import ( + "context" + "errors" + "fmt" + "time" + + "go.temporal.io/api/workflowservice/v1" + "go.temporal.io/server/common/primitives" +) + +type NamespaceAvailabilityCheck func(*workflowservice.DescribeNamespaceResponse) error + +func (tc *TestCluster) WaitForNamespaceAvailable( + ctx context.Context, + namespace string, + waitTime time.Duration, + checkInterval time.Duration, + check NamespaceAvailabilityCheck, +) error { + deadline := time.NewTimer(waitTime) + defer deadline.Stop() + ticker := time.NewTicker(checkInterval) + defer ticker.Stop() + + var lastErr error + for { + lastErr = tc.checkNamespaceAvailable(ctx, namespace, check) + if lastErr == nil { + return nil + } + + select { + case <-ctx.Done(): + return fmt.Errorf("namespace %q did not become available before context deadline: %w", namespace, errors.Join(ctx.Err(), lastErr)) + case <-deadline.C: + return fmt.Errorf("namespace %q did not become available before deadline: %w", namespace, lastErr) + case <-ticker.C: + } + } +} + +func (tc *TestCluster) checkNamespaceAvailable( + ctx context.Context, + namespace string, + check NamespaceAvailabilityCheck, +) error { + hosts := tc.host.hostsByProtocolByService[grpcProtocol][primitives.FrontendService].All + if len(hosts) == 0 { + return fmt.Errorf("no frontend gRPC hosts configured") + } + + var errs []error + for _, host := range hosts { + conn, err := tc.host.clients.newConnToAddress(primitives.FrontendService, host) + if err != nil { + errs = append(errs, fmt.Errorf("dial frontend %s: %w", host, err)) + continue + } + client := workflowservice.NewWorkflowServiceClient(conn) + resp, err := 
client.DescribeNamespace(NewContext(ctx), &workflowservice.DescribeNamespaceRequest{ + Namespace: namespace, + }) + if err == nil && check != nil { + err = check(resp) + } + if closeErr := conn.Close(); closeErr != nil { + err = errors.Join(err, fmt.Errorf("close frontend %s: %w", host, closeErr)) + } + if err != nil { + errs = append(errs, fmt.Errorf("describe namespace on frontend %s: %w", host, err)) + } + } + return errors.Join(errs...) +} diff --git a/tests/testcore/onebox.go b/tests/testcore/onebox.go index 3ac1b720a66..0da9d05262d 100644 --- a/tests/testcore/onebox.go +++ b/tests/testcore/onebox.go @@ -33,7 +33,6 @@ import ( "go.temporal.io/server/common/membership/static" "go.temporal.io/server/common/metrics" "go.temporal.io/server/common/metrics/metricstest" - "go.temporal.io/server/common/namespace" "go.temporal.io/server/common/persistence" persistenceClient "go.temporal.io/server/common/persistence/client" "go.temporal.io/server/common/persistence/visibility" @@ -50,10 +49,8 @@ import ( type ( TemporalImpl struct { - // This is used to wait for namespace registries to have noticed a change in some xdc tests. 
- namespaceRegistries []namespace.Registry - chasmEngine chasm.Engine - chasmVisibilityMgr chasm.VisibilityManager + chasmEngine chasm.Engine + chasmVisibilityMgr chasm.VisibilityManager clients clients @@ -247,10 +244,6 @@ func (c *TemporalImpl) DcClient() *dynamicconfig.MemoryClient { return c.dcClient } -func (c *TemporalImpl) NamespaceRegistries() []namespace.Registry { - return c.namespaceRegistries -} - func (c *TemporalImpl) ChasmEngine() (chasm.Engine, error) { if numHistoryHosts := len(c.hostsByProtocolByService[grpcProtocol][primitives.HistoryService].All); numHistoryHosts != 1 { return nil, fmt.Errorf("expected exactly one host for chasm engine, got %d", numHistoryHosts) @@ -326,16 +319,6 @@ func (c *TemporalImpl) installHostTestHooks( cleanups = append(cleanups, cleanup) } - addCleanup(testhooks.Set( - c.testHooks, - testhooks.NamespaceRegistryCreated, - func(name primitives.ServiceName, registry namespace.Registry) { - if name == serviceName { - c.namespaceRegistries = append(c.namespaceRegistries, registry) - } - }, - testhooks.GlobalScope, - )) addCleanup(testhooks.Set( c.testHooks, testhooks.ChasmRegistryInitializer, diff --git a/tests/xdc/base.go b/tests/xdc/base.go index 835306f6182..4f445c4a64d 100644 --- a/tests/xdc/base.go +++ b/tests/xdc/base.go @@ -4,6 +4,8 @@ import ( "cmp" "context" "errors" + "fmt" + "slices" "sync" "time" @@ -24,7 +26,6 @@ import ( "go.temporal.io/server/common/dynamicconfig" "go.temporal.io/server/common/log" "go.temporal.io/server/common/log/tag" - "go.temporal.io/server/common/namespace" "go.temporal.io/server/common/searchattribute" "go.temporal.io/server/common/testing/historyrequire" "go.temporal.io/server/common/testing/protorequire" @@ -301,29 +302,24 @@ func (s *xdcBaseSuite) createNamespace( _, err := clusters[0].FrontendClient().RegisterNamespace(ctx, regReq) s.NoError(err) - s.EventuallyWithT(func(t *assert.CollectT) { - for _, r := range clusters[0].Host().NamespaceRegistries() { - resp, err := 
r.GetNamespace(namespace.Name(ns)) - require.NoError(t, err) - require.NotNil(t, resp) - require.Equal(t, isGlobal, resp.IsGlobalNamespace()) + s.waitForNamespaceAvailable(clusters[0], ns, namespaceCacheWaitTime, func(resp *workflowservice.DescribeNamespaceResponse) error { + if resp.GetIsGlobalNamespace() != isGlobal { + return fmt.Errorf("namespace global state = %v, want %v", resp.GetIsGlobalNamespace(), isGlobal) } - }, namespaceCacheWaitTime, namespaceCacheCheckInterval) + return nil + }) if len(clusters) > 1 && isGlobal { // If namespace is global and config has more than 1 cluster, it should be replicated to these other clusters. // Check other clusters too. - s.EventuallyWithT(func(t *assert.CollectT) { - for _, c := range clusters[1:] { - for _, r := range c.Host().NamespaceRegistries() { - resp, err := r.GetNamespace(namespace.Name(ns)) - require.NoError(t, err) - require.NotNil(t, resp) - require.Equal(t, isGlobal, resp.IsGlobalNamespace()) - require.Equal(t, clusterNames, resp.ClusterNames(namespace.EmptyBusinessID)) + for _, c := range clusters[1:] { + s.waitForNamespaceAvailable(c, ns, replicationWaitTime, func(resp *workflowservice.DescribeNamespaceResponse) error { + if resp.GetIsGlobalNamespace() != isGlobal { + return fmt.Errorf("namespace global state = %v, want %v", resp.GetIsGlobalNamespace(), isGlobal) } - } - }, replicationWaitTime, replicationCheckInterval) + return compareNamespaceClusters(resp, clusterNames) + }) + } } return ns @@ -350,33 +346,27 @@ func (s *xdcBaseSuite) updateNamespaceClusters( s.NoError(err) var isGlobalNamespace bool - s.EventuallyWithT(func(t *assert.CollectT) { - for _, r := range clusters[inClusterIndex].Host().NamespaceRegistries() { - resp, err := r.GetNamespace(namespace.Name(ns)) - require.NoError(t, err) - require.NotNil(t, resp) - require.Equal(t, clusterNames, resp.ClusterNames(namespace.EmptyBusinessID)) - isGlobalNamespace = resp.IsGlobalNamespace() + s.waitForNamespaceAvailable(clusters[inClusterIndex], 
ns, namespaceCacheWaitTime, func(resp *workflowservice.DescribeNamespaceResponse) error { + if err := compareNamespaceClusters(resp, clusterNames); err != nil { + return err } - }, namespaceCacheWaitTime, namespaceCacheCheckInterval) + isGlobalNamespace = resp.GetIsGlobalNamespace() + return nil + }) if len(clusters) > 1 && isGlobalNamespace { // If namespace is global and config has more than 1 cluster, it should be replicated to these other clusters. // Check other clusters too. - s.EventuallyWithT(func(t *assert.CollectT) { - for ci, c := range clusters { - if ci == inClusterIndex { - continue - } - for _, r := range c.Host().NamespaceRegistries() { - resp, err := r.GetNamespace(namespace.Name(ns)) - require.NoError(t, err) - require.NotNil(t, resp) - require.Equal(t, clusterNames, resp.ClusterNames(namespace.EmptyBusinessID)) - } + for ci, c := range clusters { + if ci == inClusterIndex { + continue } - }, replicationWaitTime, replicationCheckInterval) + s.waitForNamespaceAvailable(c, ns, replicationWaitTime, func(resp *workflowservice.DescribeNamespaceResponse) error { + return compareNamespaceClusters(resp, clusterNames) + }) + } } + s.waitForNamespaceCacheRefresh() } func (s *xdcBaseSuite) promoteNamespace( @@ -390,14 +380,13 @@ func (s *xdcBaseSuite) promoteNamespace( }) s.NoError(err) - s.EventuallyWithT(func(t *assert.CollectT) { - for _, r := range s.clusters[inClusterIndex].Host().NamespaceRegistries() { - resp, err := r.GetNamespace(namespace.Name(ns)) - require.NoError(t, err) - require.NotNil(t, resp) - require.True(t, resp.IsGlobalNamespace()) + s.waitForNamespaceAvailable(s.clusters[inClusterIndex], ns, namespaceCacheWaitTime, func(resp *workflowservice.DescribeNamespaceResponse) error { + if !resp.GetIsGlobalNamespace() { + return fmt.Errorf("namespace is not global") } - }, namespaceCacheWaitTime, namespaceCacheCheckInterval) + return nil + }) + s.waitForNamespaceCacheRefresh() } func (s *xdcBaseSuite) failover( @@ -421,20 +410,57 @@ func (s 
*xdcBaseSuite) failover( s.Equal(targetFailoverVersion, updateResp.GetFailoverVersion()) // check local and remote clusters - s.EventuallyWithT(func(t *assert.CollectT) { - for _, c := range s.clusters { - for _, r := range c.Host().NamespaceRegistries() { - resp, err := r.GetNamespace(namespace.Name(ns)) - require.NoError(t, err) - require.NotNil(t, resp) - require.Equal(t, targetCluster, resp.ActiveClusterName(namespace.RoutingKey{})) + for _, c := range s.clusters { + s.waitForNamespaceAvailable(c, ns, replicationWaitTime, func(resp *workflowservice.DescribeNamespaceResponse) error { + if got := resp.GetReplicationConfig().GetActiveClusterName(); got != targetCluster { + return fmt.Errorf("active cluster = %q, want %q", got, targetCluster) } - } - }, replicationWaitTime, replicationCheckInterval) + return nil + }) + } + s.waitForNamespaceCacheRefresh() s.waitForClusterSynced() } +func (s *xdcBaseSuite) waitForNamespaceAvailable( + cluster *testcore.TestCluster, + ns string, + waitTime time.Duration, + check testcore.NamespaceAvailabilityCheck, +) { + s.Require().NoError(cluster.WaitForNamespaceAvailable( + testcore.NewContext(), + ns, + waitTime, + namespaceCacheCheckInterval, + check, + )) +} + +func compareNamespaceClusters(resp *workflowservice.DescribeNamespaceResponse, want []string) error { + got := make([]string, 0, len(resp.GetReplicationConfig().GetClusters())) + for _, cluster := range resp.GetReplicationConfig().GetClusters() { + got = append(got, cluster.GetClusterName()) + } + if !slices.Equal(got, want) { + return fmt.Errorf("namespace clusters = %v, want %v", got, want) + } + return nil +} + +func (s *xdcBaseSuite) waitForNamespaceCacheRefresh() { + ctx := testcore.NewContext() + timer := time.NewTimer(namespaceCacheWaitTime) + defer timer.Stop() + + select { + case <-timer.C: + case <-ctx.Done(): + s.Require().NoError(ctx.Err()) + } +} + func (s *xdcBaseSuite) newClientAndWorker(hostport, ns, taskqueue, identity string) (sdkclient.Client, 
sdkworker.Worker) { sdkClient, err := sdkclient.Dial(sdkclient.Options{ HostPort: hostport, From eabcaf77211c47fe065f741ce8bf3dc68a06a4ab Mon Sep 17 00:00:00 2001 From: Stephan Behnke Date: Sat, 20 Jun 2026 20:18:09 -0700 Subject: [PATCH 03/16] Narrow replication stream recording hook --- common/resource/fx.go | 4 - common/testing/testhooks/hooks.go | 23 ++- service/frontend/fx.go | 5 - service/fx.go | 6 - service/history/handler.go | 1 + .../history/replication/stream_observer.go | 39 ++++ .../history/replication/stream_receiver.go | 25 ++- service/history/replication/stream_sender.go | 24 +++ .../history/replication/stream_sender_test.go | 2 + tests/testcore/onebox.go | 70 +------ tests/testcore/replication_stream_recorder.go | 191 ++---------------- 11 files changed, 130 insertions(+), 260 deletions(-) create mode 100644 service/history/replication/stream_observer.go diff --git a/common/resource/fx.go b/common/resource/fx.go index 7a64bfd78a5..b8a9d99e12e 100644 --- a/common/resource/fx.go +++ b/common/resource/fx.go @@ -414,7 +414,6 @@ func DCRedirectionPolicyProvider(cfg *config.Config) config.DCRedirectionPolicy func PerServiceDialOptionsProvider( logger log.SnTaggedLogger, - testHooks testhooks.TestHooks, ) map[primitives.ServiceName][]grpc.DialOption { trailerInterceptor := interceptor.TrailerToContextMetadataInterceptor(logger) dialOpt := grpc.WithChainUnaryInterceptor(trailerInterceptor) @@ -422,9 +421,6 @@ func PerServiceDialOptionsProvider( primitives.HistoryService: {dialOpt}, primitives.MatchingService: {dialOpt}, } - if hook, ok := testhooks.Get(testHooks, testhooks.ServiceClientDialOptions, testhooks.GlobalScope); ok { - hook(options) - } return options } diff --git a/common/testing/testhooks/hooks.go b/common/testing/testhooks/hooks.go index 35c6e49336f..db95096b493 100644 --- a/common/testing/testhooks/hooks.go +++ b/common/testing/testhooks/hooks.go @@ -14,7 +14,7 @@ import ( "go.temporal.io/server/common/persistence" 
"go.temporal.io/server/common/primitives" historytasks "go.temporal.io/server/service/history/tasks" - "google.golang.org/grpc" + "google.golang.org/protobuf/proto" ) type ( @@ -23,6 +23,24 @@ type ( VisibilityManager chasm.VisibilityManager Registry *chasm.Registry } + + ReplicationStreamMessageDirection string + + ReplicationStreamMessage struct { + Method string + Direction ReplicationStreamMessageDirection + ClusterName string + TargetAddress string + Message proto.Message + IsStreamCall bool + } +) + +const ( + ReplicationStreamDirectionSend ReplicationStreamMessageDirection = "send" + ReplicationStreamDirectionRecv ReplicationStreamMessageDirection = "recv" + ReplicationStreamDirectionServerSend ReplicationStreamMessageDirection = "server_send" + ReplicationStreamDirectionServerRecv ReplicationStreamMessageDirection = "server_recv" ) // Test hook keys with their return type and scope. @@ -42,8 +60,7 @@ var ( HistoryTransferTaskInterceptor = newKey[func(historytasks.Task, func()), namespace.ID]() HistoryDLQTaskDeleteInterceptor = newKey[func(context.Context, *historyservice.DeleteDLQTasksRequest, func(context.Context, *historyservice.DeleteDLQTasksRequest) (*historyservice.DeleteDLQTasksResponse, error)) (*historyservice.DeleteDLQTasksResponse, error), global]() NamespaceReplicationTaskInterceptor = newKey[func(context.Context, *replicationspb.NamespaceTaskAttributes, func() error) error, namespace.Name]() - ServiceGrpcInterceptors = newKey[func(primitives.ServiceName, *[]grpc.UnaryServerInterceptor, *[]grpc.StreamServerInterceptor), global]() - ServiceClientDialOptions = newKey[func(map[primitives.ServiceName][]grpc.DialOption), global]() + ReplicationStreamMessageObserver = newKey[func(ReplicationStreamMessage), global]() MatchingRawClientCreated = newKey[func(primitives.ServiceName, matchingservice.MatchingServiceClient), global]() ChasmRegistryInitializer = newKey[func(*chasm.Registry) error, global]() HistoryChasmComponentsCreated = 
newKey[func(HistoryChasmComponents), global]() diff --git a/service/frontend/fx.go b/service/frontend/fx.go index c932713dda9..53b7094ceb9 100644 --- a/service/frontend/fx.go +++ b/service/frontend/fx.go @@ -246,12 +246,7 @@ func GrpcServerOptionsProvider( customInterceptors []grpc.UnaryServerInterceptor, customStreamInterceptors []grpc.StreamServerInterceptor, metricsHandler metrics.Handler, - testHooks testhooks.TestHooks, ) GrpcServerOptions { - if hook, ok := testhooks.Get(testHooks, testhooks.ServiceGrpcInterceptors, testhooks.GlobalScope); ok { - hook(serviceName, &customInterceptors, &customStreamInterceptors) - } - kep := keepalive.EnforcementPolicy{ MinTime: serviceConfig.KeepAliveMinTime(), PermitWithoutStream: serviceConfig.KeepAlivePermitWithoutStream(), diff --git a/service/fx.go b/service/fx.go index 2da1160c9a9..52093494100 100644 --- a/service/fx.go +++ b/service/fx.go @@ -15,7 +15,6 @@ import ( "go.temporal.io/server/common/rpc" "go.temporal.io/server/common/rpc/interceptor" "go.temporal.io/server/common/telemetry" - "go.temporal.io/server/common/testing/testhooks" "go.uber.org/fx" "google.golang.org/grpc" ) @@ -52,7 +51,6 @@ type ( ContextMetadataInterceptor *interceptor.ContextMetadataInterceptor `optional:"true"` AdditionalInterceptors []grpc.UnaryServerInterceptor `optional:"true"` AdditionalStreamInterceptors []grpc.StreamServerInterceptor `optional:"true"` - TestHooks testhooks.TestHooks } ) @@ -127,10 +125,6 @@ func NewPersistenceRateLimitingParams( func GrpcServerOptionsProvider( params GrpcServerOptionsParams, ) []grpc.ServerOption { - if hook, ok := testhooks.Get(params.TestHooks, testhooks.ServiceGrpcInterceptors, testhooks.GlobalScope); ok { - hook(params.ServiceName, ¶ms.AdditionalInterceptors, ¶ms.AdditionalStreamInterceptors) - } - grpcServerOptions, err := params.RPCFactory.GetInternodeGRPCServerOptions() if err != nil { params.Logger.Fatal("creating gRPC server options failed", tag.Error(err)) diff --git 
a/service/history/handler.go b/service/history/handler.go index a15e968f148..13612a3cab8 100644 --- a/service/history/handler.go +++ b/service/history/handler.go @@ -1912,6 +1912,7 @@ func (h *Handler) StreamWorkflowReplicationMessages( replication.NewClusterShardKey(clientClusterShardID.ClusterID, clientClusterShardID.ShardID), replication.NewClusterShardKey(serverClusterShardID.ClusterID, serverClusterShardID.ShardID), h.config, + h.testHooks, ) streamSender.Start() h.streamReceiverMonitor.RegisterInboundStream(streamSender) diff --git a/service/history/replication/stream_observer.go b/service/history/replication/stream_observer.go new file mode 100644 index 00000000000..b0aed170938 --- /dev/null +++ b/service/history/replication/stream_observer.go @@ -0,0 +1,39 @@ +package replication + +import ( + "go.temporal.io/server/common/testing/testhooks" + "google.golang.org/protobuf/proto" +) + +const ( + adminStreamWorkflowReplicationMessagesMethod = "/temporal.server.api.adminservice.v1.AdminService/StreamWorkflowReplicationMessages" + historyStreamWorkflowReplicationMessagesMethod = "/temporal.server.api.historyservice.v1.HistoryService/StreamWorkflowReplicationMessages" +) + +func hasReplicationStreamMessageObserver(testHooks testhooks.TestHooks) bool { + _, ok := testhooks.Get(testHooks, testhooks.ReplicationStreamMessageObserver, testhooks.GlobalScope) + return ok +} + +func observeReplicationStreamMessage( + testHooks testhooks.TestHooks, + method string, + direction testhooks.ReplicationStreamMessageDirection, + clusterName string, + targetAddress string, + msg proto.Message, +) { + if msg == nil { + return + } + if hook, ok := testhooks.Get(testHooks, testhooks.ReplicationStreamMessageObserver, testhooks.GlobalScope); ok { + hook(testhooks.ReplicationStreamMessage{ + Method: method, + Direction: direction, + ClusterName: clusterName, + TargetAddress: targetAddress, + Message: msg, + IsStreamCall: true, + }) + } +} diff --git 
a/service/history/replication/stream_receiver.go b/service/history/replication/stream_receiver.go index d0bf78622b4..0008402a9e8 100644 --- a/service/history/replication/stream_receiver.go +++ b/service/history/replication/stream_receiver.go @@ -293,7 +293,7 @@ func (r *StreamReceiverImpl) ackMessage( return 0, NewStreamError("InclusiveLowWaterMark is not set", serviceerror.NewInternal("Invalid inclusive low watermark")) } - if err := stream.Send(&adminservice.StreamWorkflowReplicationMessagesRequest{ + req := &adminservice.StreamWorkflowReplicationMessagesRequest{ Attributes: &adminservice.StreamWorkflowReplicationMessagesRequest_SyncReplicationState{ SyncReplicationState: &replicationspb.SyncReplicationState{ InclusiveLowWatermark: inclusiveLowWaterMark, @@ -302,7 +302,18 @@ func (r *StreamReceiverImpl) ackMessage( LowPriorityState: lowPriorityWatermark, }, }, - }); err != nil { + } + if hasReplicationStreamMessageObserver(r.TestHooks) { + observeReplicationStreamMessage( + r.TestHooks, + adminStreamWorkflowReplicationMessagesMethod, + testhooks.ReplicationStreamDirectionSend, + r.ClusterMetadata.GetCurrentClusterName(), + r.ClusterMetadata.ClusterNameForFailoverVersion(true, int64(r.serverShardKey.ClusterID)), + req, + ) + } + if err := stream.Send(req); err != nil { return 0, NewStreamError("stream_receiver failed to send", err) } metrics.ReplicationTasksRecvBacklog.With(r.MetricsHandler).Record( @@ -341,6 +352,16 @@ func (r *StreamReceiverImpl) processMessages( if streamResp.Err != nil { return streamResp.Err } + if hasReplicationStreamMessageObserver(r.TestHooks) { + observeReplicationStreamMessage( + r.TestHooks, + adminStreamWorkflowReplicationMessagesMethod, + testhooks.ReplicationStreamDirectionRecv, + r.ClusterMetadata.GetCurrentClusterName(), + clusterName, + streamResp.Resp, + ) + } messages := streamResp.Resp.GetMessages() priority := messages.Priority diff --git a/service/history/replication/stream_sender.go 
b/service/history/replication/stream_sender.go index 928034359f5..1d6b9aa3530 100644 --- a/service/history/replication/stream_sender.go +++ b/service/history/replication/stream_sender.go @@ -27,6 +27,7 @@ import ( "go.temporal.io/server/common/namespace" "go.temporal.io/server/common/primitives/timestamp" "go.temporal.io/server/common/quotas" + "go.temporal.io/server/common/testing/testhooks" "go.temporal.io/server/service/history/configs" historyi "go.temporal.io/server/service/history/interfaces" "go.temporal.io/server/service/history/shard" @@ -59,6 +60,7 @@ type ( recvSignalChan chan struct{} shutdownChan channel.ShutdownOnce config *configs.Config + testHooks testhooks.TestHooks isTieredStackEnabled bool flowController SenderFlowController sendLock sync.Mutex @@ -77,6 +79,7 @@ func NewStreamSender( clientShardKey ClusterShardKey, serverShardKey ClusterShardKey, config *configs.Config, + testHooks testhooks.TestHooks, ) *StreamSenderImpl { logger := log.With( shardContext.GetLogger(), @@ -100,6 +103,7 @@ func NewStreamSender( recvSignalChan: make(chan struct{}, 1), shutdownChan: channel.NewShutdownOnce(), config: config, + testHooks: testHooks, isTieredStackEnabled: config.EnableReplicationTaskTieredProcessing(), flowController: NewSenderFlowController(config, logger), ssRateLimiter: ssRateLimiter, @@ -189,6 +193,16 @@ func (s *StreamSenderImpl) recvEventLoop() (retErr error) { if err != nil { return NewStreamError("StreamSender failed to receive", err) } + if hasReplicationStreamMessageObserver(s.testHooks) { + observeReplicationStreamMessage( + s.testHooks, + historyStreamWorkflowReplicationMessagesMethod, + testhooks.ReplicationStreamDirectionServerRecv, + s.shardContext.GetClusterMetadata().GetCurrentClusterName(), + s.clientClusterName, + req, + ) + } switch attr := req.GetAttributes().(type) { case *historyservice.StreamWorkflowReplicationMessagesRequest_SyncReplicationState: if err := s.recvSyncReplicationState(attr.SyncReplicationState); err != nil { @@ 
-663,6 +677,16 @@ Loop: func (s *StreamSenderImpl) sendToStream(payload *historyservice.StreamWorkflowReplicationMessagesResponse) error { s.sendLock.Lock() defer s.sendLock.Unlock() + if hasReplicationStreamMessageObserver(s.testHooks) { + observeReplicationStreamMessage( + s.testHooks, + historyStreamWorkflowReplicationMessagesMethod, + testhooks.ReplicationStreamDirectionServerSend, + s.shardContext.GetClusterMetadata().GetCurrentClusterName(), + s.clientClusterName, + payload, + ) + } err := s.server.Send(payload) if err != nil { return NewStreamError("Stream Sender unable to send", err) diff --git a/service/history/replication/stream_sender_test.go b/service/history/replication/stream_sender_test.go index de504979221..2becaa0a0d7 100644 --- a/service/history/replication/stream_sender_test.go +++ b/service/history/replication/stream_sender_test.go @@ -24,6 +24,7 @@ import ( "go.temporal.io/server/common/persistence" "go.temporal.io/server/common/quotas" serviceerrors "go.temporal.io/server/common/serviceerror" + "go.temporal.io/server/common/testing/testhooks" "go.temporal.io/server/service/history/configs" historyi "go.temporal.io/server/service/history/interfaces" "go.temporal.io/server/service/history/shard" @@ -91,6 +92,7 @@ func (s *streamSenderSuite) SetupTest() { s.clientShardKey, s.serverShardKey, s.config, + testhooks.TestHooks{}, ) s.senderFlowController = NewMockSenderFlowController(s.controller) s.streamSender.flowController = s.senderFlowController diff --git a/tests/testcore/onebox.go b/tests/testcore/onebox.go index 0da9d05262d..08a5e43547d 100644 --- a/tests/testcore/onebox.go +++ b/tests/testcore/onebox.go @@ -44,7 +44,6 @@ import ( "go.temporal.io/server/components/nexusoperations" "go.temporal.io/server/temporal" "go.uber.org/multierr" - "google.golang.org/grpc" ) type ( @@ -141,6 +140,12 @@ func newTemporal(t *testing.T, params *TemporalParams) *TemporalImpl { impl.hostsByProtocolByService[grpcProtocol], impl.tlsConfigProvider, ) + _ = 
testhooks.Set( + impl.testHooks, + testhooks.ReplicationStreamMessageObserver, + impl.replicationStreamRecorder.Observe, + testhooks.GlobalScope, + ) // Configure output file path for on-demand logging (call WriteToLog() to write) clusterName := params.Config.ClusterMetadata.CurrentClusterName @@ -302,11 +307,6 @@ func (c *TemporalImpl) serverOptionsForHost( if c.tokenProvider != nil { options = append(options, temporal.WithTokenProvider(c.tokenProvider)) } - if serviceName == primitives.FrontendService && c.replicationStreamRecorder != nil { - options = append(options, temporal.WithChainedFrontendGrpcInterceptors( - c.replicationStreamRecorder.UnaryServerInterceptor(c.config.ClusterMetadata.CurrentClusterName), - )) - } options = append(options, c.serverOptions...) return options } @@ -327,53 +327,6 @@ func (c *TemporalImpl) installHostTestHooks( }, testhooks.GlobalScope, )) - addCleanup(testhooks.Set( - c.testHooks, - testhooks.ServiceClientDialOptions, - func(options map[primitives.ServiceName][]grpc.DialOption) { - dialOptions := c.clientDialOptions() - if len(dialOptions) == 0 { - return - } - for _, serviceName := range []primitives.ServiceName{ - primitives.FrontendService, - primitives.InternalFrontendService, - primitives.HistoryService, - primitives.MatchingService, - } { - options[serviceName] = append(options[serviceName], dialOptions...) 
- } - }, - testhooks.GlobalScope, - )) - addCleanup(testhooks.Set( - c.testHooks, - testhooks.ServiceGrpcInterceptors, - func(name primitives.ServiceName, unaryInterceptors *[]grpc.UnaryServerInterceptor, streamInterceptors *[]grpc.StreamServerInterceptor) { - switch name { - case primitives.FrontendService: - if c.replicationStreamRecorder != nil { - *streamInterceptors = append( - *streamInterceptors, - c.replicationStreamRecorder.StreamServerInterceptor(c.config.ClusterMetadata.CurrentClusterName), - ) - } - case primitives.HistoryService: - if c.replicationStreamRecorder != nil { - *unaryInterceptors = append( - *unaryInterceptors, - c.replicationStreamRecorder.UnaryServerInterceptor(c.config.ClusterMetadata.CurrentClusterName), - ) - *streamInterceptors = append( - *streamInterceptors, - c.replicationStreamRecorder.StreamServerInterceptor(c.config.ClusterMetadata.CurrentClusterName), - ) - } - default: - } - }, - testhooks.GlobalScope, - )) addCleanup(testhooks.Set( c.testHooks, testhooks.MatchingRawClientCreated, @@ -418,17 +371,6 @@ func (c *TemporalImpl) installHostTestHooks( } } -func (c *TemporalImpl) clientDialOptions() []grpc.DialOption { - var options []grpc.DialOption - if c.replicationStreamRecorder != nil { - options = append(options, - grpc.WithChainUnaryInterceptor(c.replicationStreamRecorder.UnaryInterceptor(c.config.ClusterMetadata.CurrentClusterName)), - grpc.WithChainStreamInterceptor(c.replicationStreamRecorder.StreamInterceptor(c.config.ClusterMetadata.CurrentClusterName)), - ) - } - return options -} - func (c *TemporalImpl) configForHost(serviceName primitives.ServiceName, host string) *config.Config { bindIP, port := mustSplitHostPort(host) rpcConfig := config.RPC{ diff --git a/tests/testcore/replication_stream_recorder.go b/tests/testcore/replication_stream_recorder.go index 7b53d43c612..2a0d4b7909b 100644 --- a/tests/testcore/replication_stream_recorder.go +++ b/tests/testcore/replication_stream_recorder.go @@ -1,7 +1,6 @@ package 
testcore import ( - "context" "encoding/json" "errors" "fmt" @@ -9,7 +8,7 @@ import ( "sync" "time" - "google.golang.org/grpc" + "go.temporal.io/server/common/testing/testhooks" "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/proto" ) @@ -100,7 +99,21 @@ func (r *ReplicationStreamRecorder) GetMessages() []CapturedReplicationMessage { return result } +func (r *ReplicationStreamRecorder) Observe(message testhooks.ReplicationStreamMessage) { + r.recordMessage( + message.Method, + message.Message, + string(message.Direction), + message.ClusterName, + message.TargetAddress, + message.IsStreamCall, + ) +} + func (r *ReplicationStreamRecorder) recordMessage(method string, msg proto.Message, direction string, clusterName string, targetAddr string, isStreamCall bool) { + if msg == nil { + return + } r.mu.Lock() defer r.mu.Unlock() @@ -154,177 +167,3 @@ func (r *ReplicationStreamRecorder) formatCapturedMessage(captured CapturedRepli return string(jsonOutput) } - -// UnaryInterceptor returns a gRPC unary client interceptor that captures messages -func (r *ReplicationStreamRecorder) UnaryInterceptor(clusterName string) grpc.UnaryClientInterceptor { - return func( - ctx context.Context, - method string, - req, reply any, - cc *grpc.ClientConn, - invoker grpc.UnaryInvoker, - opts ...grpc.CallOption, - ) error { - target := cc.Target() - - // Capture outgoing request if it's a replication-related call - if isReplicationMethod(method) { - if protoReq, ok := req.(proto.Message); ok { - r.recordMessage(method, protoReq, DirectionSend, clusterName, target, false) - } - } - - err := invoker(ctx, method, req, reply, cc, opts...) 
- - // Capture incoming response if successful - if err == nil && isReplicationMethod(method) { - if protoReply, ok := reply.(proto.Message); ok { - r.recordMessage(method, protoReply, DirectionRecv, clusterName, target, false) - } - } - - return err - } -} - -// StreamInterceptor returns a gRPC stream client interceptor that captures stream messages -func (r *ReplicationStreamRecorder) StreamInterceptor(clusterName string) grpc.StreamClientInterceptor { - return func( - ctx context.Context, - desc *grpc.StreamDesc, - cc *grpc.ClientConn, - method string, - streamer grpc.Streamer, - opts ...grpc.CallOption, - ) (grpc.ClientStream, error) { - stream, err := streamer(ctx, desc, cc, method, opts...) - if err != nil { - return nil, err - } - - if isReplicationMethod(method) { - return &recordingClientStream{ - ClientStream: stream, - recorder: r, - method: method, - clusterName: clusterName, - targetAddress: cc.Target(), - }, nil - } - - return stream, nil - } -} - -// recordingClientStream wraps a grpc.ClientStream to record messages -type recordingClientStream struct { - grpc.ClientStream - recorder *ReplicationStreamRecorder - method string - clusterName string - targetAddress string -} - -func (s *recordingClientStream) SendMsg(m any) error { - if msg, ok := m.(proto.Message); ok { - // SendMsg means this cluster is SENDING a message (could be request or ack) - s.recorder.recordMessage(s.method, msg, DirectionSend, s.clusterName, s.targetAddress, true) - } - return s.ClientStream.SendMsg(m) -} - -func (s *recordingClientStream) RecvMsg(m any) error { - err := s.ClientStream.RecvMsg(m) - if err == nil { - if msg, ok := m.(proto.Message); ok { - // RecvMsg means this cluster is RECEIVING a message (could be request or data) - s.recorder.recordMessage(s.method, msg, DirectionRecv, s.clusterName, s.targetAddress, true) - } - } - return err -} - -// UnaryServerInterceptor returns a gRPC unary server interceptor that captures messages -func (r *ReplicationStreamRecorder) 
UnaryServerInterceptor(clusterName string) grpc.UnaryServerInterceptor { - return func( - ctx context.Context, - req any, - info *grpc.UnaryServerInfo, - handler grpc.UnaryHandler, - ) (any, error) { - // Capture incoming request if it's a replication-related call - if isReplicationMethod(info.FullMethod) { - if protoReq, ok := req.(proto.Message); ok { - r.recordMessage(info.FullMethod, protoReq, DirectionServerRecv, clusterName, "server", false) - } - } - - resp, err := handler(ctx, req) - - // Capture outgoing response if successful - if err == nil && isReplicationMethod(info.FullMethod) { - if protoResp, ok := resp.(proto.Message); ok { - r.recordMessage(info.FullMethod, protoResp, DirectionServerSend, clusterName, "server", false) - } - } - - return resp, err - } -} - -// StreamServerInterceptor returns a gRPC stream server interceptor that captures stream messages -func (r *ReplicationStreamRecorder) StreamServerInterceptor(clusterName string) grpc.StreamServerInterceptor { - return func( - srv any, - ss grpc.ServerStream, - info *grpc.StreamServerInfo, - handler grpc.StreamHandler, - ) error { - if isReplicationMethod(info.FullMethod) { - wrappedStream := &recordingServerStream{ - ServerStream: ss, - recorder: r, - method: info.FullMethod, - clusterName: clusterName, - } - return handler(srv, wrappedStream) - } - - return handler(srv, ss) - } -} - -// recordingServerStream wraps a grpc.ServerStream to record messages -type recordingServerStream struct { - grpc.ServerStream - recorder *ReplicationStreamRecorder - method string - clusterName string -} - -func (s *recordingServerStream) SendMsg(m any) error { - if msg, ok := m.(proto.Message); ok { - // Server SendMsg means this server is SENDING a message to the client - s.recorder.recordMessage(s.method, msg, DirectionServerSend, s.clusterName, "server", true) - } - return s.ServerStream.SendMsg(m) -} - -func (s *recordingServerStream) RecvMsg(m any) error { - err := s.ServerStream.RecvMsg(m) - if err == nil 
{ - if msg, ok := m.(proto.Message); ok { - // Server RecvMsg means this server is RECEIVING a message from the client - s.recorder.recordMessage(s.method, msg, DirectionServerRecv, s.clusterName, "server", true) - } - } - return err -} - -func isReplicationMethod(method string) bool { - // Capture StreamWorkflowReplicationMessages from both history and admin services - // - Sender (active) uses history service to respond to receiver - // - Receiver (standby) uses admin service to call sender - return method == "/temporal.server.api.historyservice.v1.HistoryService/StreamWorkflowReplicationMessages" || - method == "/temporal.server.api.adminservice.v1.AdminService/StreamWorkflowReplicationMessages" -} From 3da7a46736abc145ad88b4216b5ee5010f058220 Mon Sep 17 00:00:00 2001 From: Stephan Behnke Date: Sat, 20 Jun 2026 20:23:26 -0700 Subject: [PATCH 04/16] Narrow task queue recorder hook --- common/persistence/client/fx.go | 5 +- .../client/history_tasks_observer.go | 110 +++++++ common/testing/testhooks/hooks.go | 12 +- tests/testcore/onebox.go | 18 +- tests/testcore/task_queue_recorder.go | 282 +----------------- tests/testcore/test_cluster.go | 3 - 6 files changed, 143 insertions(+), 287 deletions(-) create mode 100644 common/persistence/client/history_tasks_observer.go diff --git a/common/persistence/client/fx.go b/common/persistence/client/fx.go index 3cc9b9228ce..a6a9e228d29 100644 --- a/common/persistence/client/fx.go +++ b/common/persistence/client/fx.go @@ -224,7 +224,6 @@ type managerProviderParams struct { Factory Factory Lifecycle fx.Lifecycle TestHooks testhooks.TestHooks `optional:"true"` - Logger log.Logger } func managerProvider[T persistence.Closeable](newManagerFn func(Factory) (T, error)) func(managerProviderParams) (T, error) { @@ -240,8 +239,8 @@ func managerProvider[T persistence.Closeable](newManagerFn func(Factory) (T, err return nilT, err } if executionManager, ok := any(manager).(persistence.ExecutionManager); ok { - if hook, ok := 
testhooks.Get(params.TestHooks, testhooks.PersistenceExecutionManagerWrapper, testhooks.GlobalScope); ok { - manager = any(hook(executionManager, params.Logger)).(T) + if hook, ok := testhooks.Get(params.TestHooks, testhooks.HistoryTasksWrittenObserver, testhooks.GlobalScope); ok { + manager = any(newHistoryTasksWrittenObserver(executionManager, hook)).(T) } } params.Lifecycle.Append(fx.StopHook(manager.Close)) diff --git a/common/persistence/client/history_tasks_observer.go b/common/persistence/client/history_tasks_observer.go new file mode 100644 index 00000000000..99cc4a95206 --- /dev/null +++ b/common/persistence/client/history_tasks_observer.go @@ -0,0 +1,110 @@ +package client + +import ( + "context" + + "go.temporal.io/server/common/persistence" + "go.temporal.io/server/common/testing/testhooks" +) + +type historyTasksWrittenObserver struct { + persistence.ExecutionManager + observe func(testhooks.HistoryTasksWritten) +} + +func newHistoryTasksWrittenObserver( + manager persistence.ExecutionManager, + observe func(testhooks.HistoryTasksWritten), +) persistence.ExecutionManager { + return &historyTasksWrittenObserver{ + ExecutionManager: manager, + observe: observe, + } +} + +func (o *historyTasksWrittenObserver) AddHistoryTasks( + ctx context.Context, + request *persistence.AddHistoryTasksRequest, +) error { + err := o.ExecutionManager.AddHistoryTasks(ctx, request) + if err == nil && request != nil { + o.observe(testhooks.HistoryTasksWritten{ + ShardID: request.ShardID, + RangeID: request.RangeID, + NamespaceID: request.NamespaceID, + WorkflowID: request.WorkflowID, + Tasks: request.Tasks, + }) + } + return err +} + +func (o *historyTasksWrittenObserver) CreateWorkflowExecution( + ctx context.Context, + request *persistence.CreateWorkflowExecutionRequest, +) (*persistence.CreateWorkflowExecutionResponse, error) { + response, err := o.ExecutionManager.CreateWorkflowExecution(ctx, request) + if err == nil && request != nil { + 
o.observeWorkflowSnapshot(request.ShardID, request.RangeID, &request.NewWorkflowSnapshot) + } + return response, err +} + +func (o *historyTasksWrittenObserver) UpdateWorkflowExecution( + ctx context.Context, + request *persistence.UpdateWorkflowExecutionRequest, +) (*persistence.UpdateWorkflowExecutionResponse, error) { + response, err := o.ExecutionManager.UpdateWorkflowExecution(ctx, request) + if err == nil && request != nil { + o.observeWorkflowMutation(request.ShardID, request.RangeID, &request.UpdateWorkflowMutation) + o.observeWorkflowSnapshot(request.ShardID, request.RangeID, request.NewWorkflowSnapshot) + } + return response, err +} + +func (o *historyTasksWrittenObserver) ConflictResolveWorkflowExecution( + ctx context.Context, + request *persistence.ConflictResolveWorkflowExecutionRequest, +) (*persistence.ConflictResolveWorkflowExecutionResponse, error) { + response, err := o.ExecutionManager.ConflictResolveWorkflowExecution(ctx, request) + if err == nil && request != nil { + o.observeWorkflowSnapshot(request.ShardID, request.RangeID, &request.ResetWorkflowSnapshot) + o.observeWorkflowSnapshot(request.ShardID, request.RangeID, request.NewWorkflowSnapshot) + o.observeWorkflowMutation(request.ShardID, request.RangeID, request.CurrentWorkflowMutation) + } + return response, err +} + +func (o *historyTasksWrittenObserver) observeWorkflowSnapshot( + shardID int32, + rangeID int64, + snapshot *persistence.WorkflowSnapshot, +) { + if snapshot == nil || snapshot.ExecutionInfo == nil { + return + } + o.observe(testhooks.HistoryTasksWritten{ + ShardID: shardID, + RangeID: rangeID, + NamespaceID: snapshot.ExecutionInfo.NamespaceId, + WorkflowID: snapshot.ExecutionInfo.WorkflowId, + Tasks: snapshot.Tasks, + }) +} + +func (o *historyTasksWrittenObserver) observeWorkflowMutation( + shardID int32, + rangeID int64, + mutation *persistence.WorkflowMutation, +) { + if mutation == nil || mutation.ExecutionInfo == nil { + return + } + 
o.observe(testhooks.HistoryTasksWritten{ + ShardID: shardID, + RangeID: rangeID, + NamespaceID: mutation.ExecutionInfo.NamespaceId, + WorkflowID: mutation.ExecutionInfo.WorkflowId, + Tasks: mutation.Tasks, + }) +} diff --git a/common/testing/testhooks/hooks.go b/common/testing/testhooks/hooks.go index db95096b493..9498bef8cf2 100644 --- a/common/testing/testhooks/hooks.go +++ b/common/testing/testhooks/hooks.go @@ -9,9 +9,7 @@ import ( persistencespb "go.temporal.io/server/api/persistence/v1" replicationspb "go.temporal.io/server/api/replication/v1" "go.temporal.io/server/chasm" - "go.temporal.io/server/common/log" "go.temporal.io/server/common/namespace" - "go.temporal.io/server/common/persistence" "go.temporal.io/server/common/primitives" historytasks "go.temporal.io/server/service/history/tasks" "google.golang.org/protobuf/proto" @@ -34,6 +32,14 @@ type ( Message proto.Message IsStreamCall bool } + + HistoryTasksWritten struct { + ShardID int32 + RangeID int64 + NamespaceID string + WorkflowID string + Tasks map[historytasks.Category][]historytasks.Task + } ) const ( @@ -60,11 +66,11 @@ var ( HistoryTransferTaskInterceptor = newKey[func(historytasks.Task, func()), namespace.ID]() HistoryDLQTaskDeleteInterceptor = newKey[func(context.Context, *historyservice.DeleteDLQTasksRequest, func(context.Context, *historyservice.DeleteDLQTasksRequest) (*historyservice.DeleteDLQTasksResponse, error)) (*historyservice.DeleteDLQTasksResponse, error), global]() NamespaceReplicationTaskInterceptor = newKey[func(context.Context, *replicationspb.NamespaceTaskAttributes, func() error) error, namespace.Name]() + HistoryTasksWrittenObserver = newKey[func(HistoryTasksWritten), global]() ReplicationStreamMessageObserver = newKey[func(ReplicationStreamMessage), global]() MatchingRawClientCreated = newKey[func(primitives.ServiceName, matchingservice.MatchingServiceClient), global]() ChasmRegistryInitializer = newKey[func(*chasm.Registry) error, global]() HistoryChasmComponentsCreated = 
newKey[func(HistoryChasmComponents), global]() - PersistenceExecutionManagerWrapper = newKey[func(persistence.ExecutionManager, log.Logger) persistence.ExecutionManager, global]() ) // keyID is a unique identifier for a key, used as a map key. diff --git a/tests/testcore/onebox.go b/tests/testcore/onebox.go index 08a5e43547d..a3e41cf6607 100644 --- a/tests/testcore/onebox.go +++ b/tests/testcore/onebox.go @@ -33,7 +33,6 @@ import ( "go.temporal.io/server/common/membership/static" "go.temporal.io/server/common/metrics" "go.temporal.io/server/common/metrics/metricstest" - "go.temporal.io/server/common/persistence" persistenceClient "go.temporal.io/server/common/persistence/client" "go.temporal.io/server/common/persistence/visibility" "go.temporal.io/server/common/primitives" @@ -71,6 +70,7 @@ type ( callbackLock sync.RWMutex // Must be used for above callbacks chasmRegistry *chasm.Registry replicationStreamRecorder *ReplicationStreamRecorder + taskQueueRecorderOnce sync.Once taskQueueRecorder *TaskQueueRecorder servers []*temporal.ServerFx @@ -352,13 +352,8 @@ func (c *TemporalImpl) installHostTestHooks( )) addCleanup(testhooks.Set( c.testHooks, - testhooks.PersistenceExecutionManagerWrapper, - func(base persistence.ExecutionManager, logger log.Logger) persistence.ExecutionManager { - // Wrap ExecutionManager with recorder to capture task writes - // This wraps the FINAL ExecutionManager after all FX processing (metrics, retries, etc.) 
- c.taskQueueRecorder = NewTaskQueueRecorder(base, logger) - return c.taskQueueRecorder - }, + testhooks.HistoryTasksWrittenObserver, + c.getTaskQueueRecorder().Observe, testhooks.GlobalScope, )) default: @@ -398,6 +393,13 @@ func (c *TemporalImpl) configForHost(serviceName primitives.ServiceName, host st } func (c *TemporalImpl) GetTaskQueueRecorder() *TaskQueueRecorder { + return c.getTaskQueueRecorder() +} + +func (c *TemporalImpl) getTaskQueueRecorder() *TaskQueueRecorder { + c.taskQueueRecorderOnce.Do(func() { + c.taskQueueRecorder = NewTaskQueueRecorder(c.logger) + }) return c.taskQueueRecorder } diff --git a/tests/testcore/task_queue_recorder.go b/tests/testcore/task_queue_recorder.go index 2e08d478ba1..bf006a21907 100644 --- a/tests/testcore/task_queue_recorder.go +++ b/tests/testcore/task_queue_recorder.go @@ -1,7 +1,6 @@ package testcore import ( - "context" "encoding/json" "fmt" "os" @@ -9,21 +8,20 @@ import ( "time" "go.temporal.io/server/common/log" - "go.temporal.io/server/common/persistence" + "go.temporal.io/server/common/testing/testhooks" "go.temporal.io/server/service/history/tasks" ) -// TaskQueueRecorder wraps an ExecutionManager to record ALL task writes +// TaskQueueRecorder records ALL task writes // to the history task queues (transfer, timer, replication, visibility, archival, etc.). // This is useful for integration tests where you want to assert on what tasks // were generated and in what order. // Tasks are stored flattened by category - all tasks of the same type are in a single list, // with each task wrapped with metadata about when/where it was written. 
type TaskQueueRecorder struct { - mu sync.RWMutex - tasks map[tasks.Category][]RecordedTask // All tasks by category, in order - delegate persistence.ExecutionManager - logger log.Logger + mu sync.RWMutex + tasks map[tasks.Category][]RecordedTask // All tasks by category, in order + logger log.Logger } // RecordedTask wraps a task with metadata about when and where it was written @@ -38,83 +36,17 @@ type RecordedTask struct { Task tasks.Task `json:"task"` // The actual task object } -// NewTaskQueueRecorder creates a recorder that wraps the given ExecutionManager -func NewTaskQueueRecorder(delegate persistence.ExecutionManager, logger log.Logger) *TaskQueueRecorder { +// NewTaskQueueRecorder creates a recorder for history task writes. +func NewTaskQueueRecorder(logger log.Logger) *TaskQueueRecorder { return &TaskQueueRecorder{ - tasks: make(map[tasks.Category][]RecordedTask), - delegate: delegate, - logger: logger, + tasks: make(map[tasks.Category][]RecordedTask), + logger: logger, } } -// AddHistoryTasks records the task write and then delegates to the underlying manager -func (r *TaskQueueRecorder) AddHistoryTasks( - ctx context.Context, - request *persistence.AddHistoryTasksRequest, -) error { - // Call the delegate first - err := r.delegate.AddHistoryTasks(ctx, request) - - // Only record if successful - if err == nil { - r.recordTasks(request.ShardID, 0, request.NamespaceID, request.WorkflowID, request.Tasks) - } - - return err -} - -func (r *TaskQueueRecorder) UpdateWorkflowExecution( - ctx context.Context, - request *persistence.UpdateWorkflowExecutionRequest, -) (*persistence.UpdateWorkflowExecutionResponse, error) { - // Call the delegate first - resp, err := r.delegate.UpdateWorkflowExecution(ctx, request) - - // Only record if successful - if err == nil { - // Record tasks from the mutation - r.recordTasks( - request.ShardID, - request.RangeID, - request.UpdateWorkflowMutation.ExecutionInfo.NamespaceId, - 
request.UpdateWorkflowMutation.ExecutionInfo.WorkflowId, - request.UpdateWorkflowMutation.Tasks, - ) - - // Record tasks from new workflow snapshot if present - if request.NewWorkflowSnapshot != nil { - r.recordTasks( - request.ShardID, - request.RangeID, - request.NewWorkflowSnapshot.ExecutionInfo.NamespaceId, - request.NewWorkflowSnapshot.ExecutionInfo.WorkflowId, - request.NewWorkflowSnapshot.Tasks, - ) - } - } - - return resp, err -} - -func (r *TaskQueueRecorder) CreateWorkflowExecution( - ctx context.Context, - request *persistence.CreateWorkflowExecutionRequest, -) (*persistence.CreateWorkflowExecutionResponse, error) { - // Call the delegate first - resp, err := r.delegate.CreateWorkflowExecution(ctx, request) - - // Only record if successful - if err == nil { - r.recordTasks( - request.ShardID, - request.RangeID, - request.NewWorkflowSnapshot.ExecutionInfo.NamespaceId, - request.NewWorkflowSnapshot.ExecutionInfo.WorkflowId, - request.NewWorkflowSnapshot.Tasks, - ) - } - - return resp, err +// Observe records one successful history task write. 
+func (r *TaskQueueRecorder) Observe(write testhooks.HistoryTasksWritten) { + r.recordTasks(write.ShardID, write.RangeID, write.NamespaceID, write.WorkflowID, write.Tasks) } // recordTasks appends tasks to the flattened list by category, wrapping each with metadata @@ -380,193 +312,3 @@ func writeFile(filePath string, data []byte) error { _, err = file.Write(data) return err } - -// Delegate all other ExecutionManager methods to the underlying implementation -// These are pass-through methods that don't need recording - -func (r *TaskQueueRecorder) GetName() string { - return r.delegate.GetName() -} - -func (r *TaskQueueRecorder) Close() { - r.delegate.Close() -} - -func (r *TaskQueueRecorder) GetWorkflowExecution( - ctx context.Context, - request *persistence.GetWorkflowExecutionRequest, -) (*persistence.GetWorkflowExecutionResponse, error) { - return r.delegate.GetWorkflowExecution(ctx, request) -} - -func (r *TaskQueueRecorder) ConflictResolveWorkflowExecution( - ctx context.Context, - request *persistence.ConflictResolveWorkflowExecutionRequest, -) (*persistence.ConflictResolveWorkflowExecutionResponse, error) { - return r.delegate.ConflictResolveWorkflowExecution(ctx, request) -} - -func (r *TaskQueueRecorder) DeleteWorkflowExecution( - ctx context.Context, - request *persistence.DeleteWorkflowExecutionRequest, -) error { - return r.delegate.DeleteWorkflowExecution(ctx, request) -} - -func (r *TaskQueueRecorder) DeleteCurrentWorkflowExecution( - ctx context.Context, - request *persistence.DeleteCurrentWorkflowExecutionRequest, -) error { - return r.delegate.DeleteCurrentWorkflowExecution(ctx, request) -} - -func (r *TaskQueueRecorder) GetCurrentExecution( - ctx context.Context, - request *persistence.GetCurrentExecutionRequest, -) (*persistence.GetCurrentExecutionResponse, error) { - return r.delegate.GetCurrentExecution(ctx, request) -} - -func (r *TaskQueueRecorder) SetWorkflowExecution( - ctx context.Context, - request 
*persistence.SetWorkflowExecutionRequest, -) (*persistence.SetWorkflowExecutionResponse, error) { - return r.delegate.SetWorkflowExecution(ctx, request) -} - -func (r *TaskQueueRecorder) ListConcreteExecutions( - ctx context.Context, - request *persistence.ListConcreteExecutionsRequest, -) (*persistence.ListConcreteExecutionsResponse, error) { - return r.delegate.ListConcreteExecutions(ctx, request) -} - -func (r *TaskQueueRecorder) GetHistoryTasks( - ctx context.Context, - request *persistence.GetHistoryTasksRequest, -) (*persistence.GetHistoryTasksResponse, error) { - return r.delegate.GetHistoryTasks(ctx, request) -} - -func (r *TaskQueueRecorder) CompleteHistoryTask( - ctx context.Context, - request *persistence.CompleteHistoryTaskRequest, -) error { - return r.delegate.CompleteHistoryTask(ctx, request) -} - -func (r *TaskQueueRecorder) RangeCompleteHistoryTasks( - ctx context.Context, - request *persistence.RangeCompleteHistoryTasksRequest, -) error { - return r.delegate.RangeCompleteHistoryTasks(ctx, request) -} - -func (r *TaskQueueRecorder) PutReplicationTaskToDLQ( - ctx context.Context, - request *persistence.PutReplicationTaskToDLQRequest, -) error { - return r.delegate.PutReplicationTaskToDLQ(ctx, request) -} - -func (r *TaskQueueRecorder) GetReplicationTasksFromDLQ( - ctx context.Context, - request *persistence.GetReplicationTasksFromDLQRequest, -) (*persistence.GetHistoryTasksResponse, error) { - return r.delegate.GetReplicationTasksFromDLQ(ctx, request) -} - -func (r *TaskQueueRecorder) DeleteReplicationTaskFromDLQ( - ctx context.Context, - request *persistence.DeleteReplicationTaskFromDLQRequest, -) error { - return r.delegate.DeleteReplicationTaskFromDLQ(ctx, request) -} - -func (r *TaskQueueRecorder) RangeDeleteReplicationTaskFromDLQ( - ctx context.Context, - request *persistence.RangeDeleteReplicationTaskFromDLQRequest, -) error { - return r.delegate.RangeDeleteReplicationTaskFromDLQ(ctx, request) -} - -func (r *TaskQueueRecorder) 
IsReplicationDLQEmpty( - ctx context.Context, - request *persistence.GetReplicationTasksFromDLQRequest, -) (bool, error) { - return r.delegate.IsReplicationDLQEmpty(ctx, request) -} - -func (r *TaskQueueRecorder) GetHistoryBranchUtil() persistence.HistoryBranchUtil { - return r.delegate.GetHistoryBranchUtil() -} - -func (r *TaskQueueRecorder) AppendHistoryNodes( - ctx context.Context, - request *persistence.AppendHistoryNodesRequest, -) (*persistence.AppendHistoryNodesResponse, error) { - return r.delegate.AppendHistoryNodes(ctx, request) -} - -func (r *TaskQueueRecorder) AppendRawHistoryNodes( - ctx context.Context, - request *persistence.AppendRawHistoryNodesRequest, -) (*persistence.AppendHistoryNodesResponse, error) { - return r.delegate.AppendRawHistoryNodes(ctx, request) -} - -func (r *TaskQueueRecorder) ReadHistoryBranch( - ctx context.Context, - request *persistence.ReadHistoryBranchRequest, -) (*persistence.ReadHistoryBranchResponse, error) { - return r.delegate.ReadHistoryBranch(ctx, request) -} - -func (r *TaskQueueRecorder) ReadHistoryBranchByBatch( - ctx context.Context, - request *persistence.ReadHistoryBranchRequest, -) (*persistence.ReadHistoryBranchByBatchResponse, error) { - return r.delegate.ReadHistoryBranchByBatch(ctx, request) -} - -func (r *TaskQueueRecorder) ReadHistoryBranchReverse( - ctx context.Context, - request *persistence.ReadHistoryBranchReverseRequest, -) (*persistence.ReadHistoryBranchReverseResponse, error) { - return r.delegate.ReadHistoryBranchReverse(ctx, request) -} - -func (r *TaskQueueRecorder) ReadRawHistoryBranch( - ctx context.Context, - request *persistence.ReadHistoryBranchRequest, -) (*persistence.ReadRawHistoryBranchResponse, error) { - return r.delegate.ReadRawHistoryBranch(ctx, request) -} - -func (r *TaskQueueRecorder) ForkHistoryBranch( - ctx context.Context, - request *persistence.ForkHistoryBranchRequest, -) (*persistence.ForkHistoryBranchResponse, error) { - return r.delegate.ForkHistoryBranch(ctx, request) -} 
- -func (r *TaskQueueRecorder) DeleteHistoryBranch( - ctx context.Context, - request *persistence.DeleteHistoryBranchRequest, -) error { - return r.delegate.DeleteHistoryBranch(ctx, request) -} - -func (r *TaskQueueRecorder) TrimHistoryBranch( - ctx context.Context, - request *persistence.TrimHistoryBranchRequest, -) (*persistence.TrimHistoryBranchResponse, error) { - return r.delegate.TrimHistoryBranch(ctx, request) -} - -func (r *TaskQueueRecorder) GetAllHistoryTreeBranches( - ctx context.Context, - request *persistence.GetAllHistoryTreeBranchesRequest, -) (*persistence.GetAllHistoryTreeBranchesResponse, error) { - return r.delegate.GetAllHistoryTreeBranches(ctx, request) -} diff --git a/tests/testcore/test_cluster.go b/tests/testcore/test_cluster.go index 506a6d40429..f22e58b9cc0 100644 --- a/tests/testcore/test_cluster.go +++ b/tests/testcore/test_cluster.go @@ -491,9 +491,6 @@ func (tc *TestCluster) SchedulerClient() schedulerpb.SchedulerServiceClient { // ExecutionManager returns an execution manager factory from the test cluster func (tc *TestCluster) ExecutionManager() persistence.ExecutionManager { - if tc.host.taskQueueRecorder != nil { - return tc.host.taskQueueRecorder - } return tc.testBase.ExecutionManager } From 302ffae8d2711cd10112add35ff9b32be8c62542 Mon Sep 17 00:00:00 2001 From: Stephan Behnke Date: Sat, 20 Jun 2026 20:25:54 -0700 Subject: [PATCH 05/16] Add CHASM library server option --- common/testing/testhooks/hooks.go | 1 - temporal/fx.go | 23 +++++++++++++++++------ temporal/server_option.go | 9 +++++++++ temporal/server_options.go | 2 ++ tests/testcore/onebox.go | 9 +-------- 5 files changed, 29 insertions(+), 15 deletions(-) diff --git a/common/testing/testhooks/hooks.go b/common/testing/testhooks/hooks.go index 9498bef8cf2..cddc4d14587 100644 --- a/common/testing/testhooks/hooks.go +++ b/common/testing/testhooks/hooks.go @@ -69,7 +69,6 @@ var ( HistoryTasksWrittenObserver = newKey[func(HistoryTasksWritten), global]() 
ReplicationStreamMessageObserver = newKey[func(ReplicationStreamMessage), global]() MatchingRawClientCreated = newKey[func(primitives.ServiceName, matchingservice.MatchingServiceClient), global]() - ChasmRegistryInitializer = newKey[func(*chasm.Registry) error, global]() HistoryChasmComponentsCreated = newKey[func(HistoryChasmComponents), global]() ) diff --git a/temporal/fx.go b/temporal/fx.go index ac9a6490489..3c1d420dbd1 100644 --- a/temporal/fx.go +++ b/temporal/fx.go @@ -127,6 +127,7 @@ type ( EsClient esclient.Client MetricsHandler metrics.Handler TestHooks testhooks.TestHooks + ChasmLibraries []chasm.Library } ) @@ -322,6 +323,7 @@ func ServerOptionsProvider(opts []ServerOption) (serverOptionsProvider, error) { EsClient: esClient, MetricsHandler: metricHandler, TestHooks: testHooks, + ChasmLibraries: so.chasmLibraries, }, nil } @@ -389,6 +391,7 @@ type ( StaticServiceHosts map[primitives.ServiceName]static.Hosts `optional:"true"` TaskCategoryRegistry tasks.TaskCategoryRegistry TestHooks testhooks.TestHooks + ChasmLibraries []chasm.Library } ) @@ -470,6 +473,9 @@ func (params ServiceProviderParamsCommon) GetCommonServiceOptions(serviceName pr func() tasks.TaskCategoryRegistry { return params.TaskCategoryRegistry }, + func() []chasm.Library { + return params.ChasmLibraries + }, ), fx.Decorate(func() testhooks.TestHooks { return params.TestHooks @@ -479,20 +485,25 @@ func (params ServiceProviderParamsCommon) GetCommonServiceOptions(serviceName pr membershipModule, FxLogAdapter, chasm.Module, - fx.Invoke(ChasmRegistryInitializerHook), + fx.Invoke(ChasmLibrariesInitializer), ) } -type chasmRegistryInitializerHookParams struct { +type chasmLibrariesInitializerParams struct { fx.In Registry *chasm.Registry - TestHooks testhooks.TestHooks + Libraries []chasm.Library } -func ChasmRegistryInitializerHook(params chasmRegistryInitializerHookParams) error { - if hook, ok := testhooks.Get(params.TestHooks, testhooks.ChasmRegistryInitializer, testhooks.GlobalScope); ok { 
- return hook(params.Registry) +func ChasmLibrariesInitializer(params chasmLibrariesInitializerParams) error { + for _, library := range params.Libraries { + if library == nil { + return errors.New("cannot register nil CHASM library") + } + if err := params.Registry.Register(library); err != nil { + return fmt.Errorf("register CHASM library %q: %w", library.Name(), err) + } } return nil } diff --git a/temporal/server_option.go b/temporal/server_option.go index 11d9d5a2c08..03f2f9fd6fd 100644 --- a/temporal/server_option.go +++ b/temporal/server_option.go @@ -3,6 +3,7 @@ package temporal import ( "net/http" + "go.temporal.io/server/chasm" "go.temporal.io/server/client" "go.temporal.io/server/common/archiver/provider" "go.temporal.io/server/common/authorization" @@ -213,6 +214,14 @@ func WithTestHooks(testHooks testhooks.TestHooks) ServerOption { }) } +// WithChasmLibraries registers additional CHASM libraries in each service graph. +// NOTE: this option is experimental and may be changed or removed in future release. +func WithChasmLibraries(libraries ...chasm.Library) ServerOption { + return applyFunc(func(s *serverOptions) { + s.chasmLibraries = append(s.chasmLibraries, libraries...) 
+ }) +} + // WithCustomerMetricsProvider sets a custom implementation of the metrics.MetricsHandler interface // metrics.MetricsHandler is the base interface for publishing metric events func WithCustomMetricsHandler(provider metrics.Handler) ServerOption { diff --git a/temporal/server_options.go b/temporal/server_options.go index 96fa6d8ed3f..7d8687c46cb 100644 --- a/temporal/server_options.go +++ b/temporal/server_options.go @@ -6,6 +6,7 @@ import ( "net/http" "slices" + "go.temporal.io/server/chasm" "go.temporal.io/server/client" "go.temporal.io/server/common/archiver/provider" "go.temporal.io/server/common/authorization" @@ -62,6 +63,7 @@ type ( metricHandler metrics.Handler tokenProvider auth.TokenProvider testHooks testhooks.TestHooks + chasmLibraries []chasm.Library } ) diff --git a/tests/testcore/onebox.go b/tests/testcore/onebox.go index a3e41cf6607..4c2ce857a7d 100644 --- a/tests/testcore/onebox.go +++ b/tests/testcore/onebox.go @@ -300,6 +300,7 @@ func (c *TemporalImpl) serverOptionsForHost( temporal.WithSearchAttributesMapper(nil), temporal.WithPersistenceServiceResolver(resolver.NewNoopResolver()), temporal.WithCustomMetricsHandler(c.GetMetricsHandler()), + temporal.WithChasmLibraries(chasmtests.Library), } if c.tlsConfigProvider != nil { options = append(options, temporal.WithTLSConfigFactory(c.tlsConfigProvider)) @@ -319,14 +320,6 @@ func (c *TemporalImpl) installHostTestHooks( cleanups = append(cleanups, cleanup) } - addCleanup(testhooks.Set( - c.testHooks, - testhooks.ChasmRegistryInitializer, - func(registry *chasm.Registry) error { - return registry.Register(chasmtests.Library) - }, - testhooks.GlobalScope, - )) addCleanup(testhooks.Set( c.testHooks, testhooks.MatchingRawClientCreated, From a92bcadb6570b8261caa1aa70c7d9f66526b9dc6 Mon Sep 17 00:00:00 2001 From: Stephan Behnke Date: Sat, 20 Jun 2026 20:29:47 -0700 Subject: [PATCH 06/16] Build onebox matching client directly --- common/resource/fx.go | 5 - common/testing/testhooks/hooks.go | 3 - 
tests/testcore/clients.go | 177 +++++++++++++++++++++++++++++- tests/testcore/onebox.go | 14 +-- 4 files changed, 177 insertions(+), 22 deletions(-) diff --git a/common/resource/fx.go b/common/resource/fx.go index b8a9d99e12e..03ec10713b0 100644 --- a/common/resource/fx.go +++ b/common/resource/fx.go @@ -328,18 +328,13 @@ func HistoryClientProvider(historyRawClient HistoryRawClient) HistoryClient { } func MatchingRawClientProvider( - serviceName primitives.ServiceName, clientBean client.Bean, namespaceRegistry namespace.Registry, - testHooks testhooks.TestHooks, ) (MatchingRawClient, error) { client, err := clientBean.GetMatchingClient(namespaceRegistry.GetNamespaceName) if err != nil { return nil, err } - if hook, ok := testhooks.Get(testHooks, testhooks.MatchingRawClientCreated, testhooks.GlobalScope); ok { - hook(serviceName, client) - } return client, nil } diff --git a/common/testing/testhooks/hooks.go b/common/testing/testhooks/hooks.go index cddc4d14587..fa389506bd9 100644 --- a/common/testing/testhooks/hooks.go +++ b/common/testing/testhooks/hooks.go @@ -5,12 +5,10 @@ import ( "time" "go.temporal.io/server/api/historyservice/v1" - "go.temporal.io/server/api/matchingservice/v1" persistencespb "go.temporal.io/server/api/persistence/v1" replicationspb "go.temporal.io/server/api/replication/v1" "go.temporal.io/server/chasm" "go.temporal.io/server/common/namespace" - "go.temporal.io/server/common/primitives" historytasks "go.temporal.io/server/service/history/tasks" "google.golang.org/protobuf/proto" ) @@ -68,7 +66,6 @@ var ( NamespaceReplicationTaskInterceptor = newKey[func(context.Context, *replicationspb.NamespaceTaskAttributes, func() error) error, namespace.Name]() HistoryTasksWrittenObserver = newKey[func(HistoryTasksWritten), global]() ReplicationStreamMessageObserver = newKey[func(ReplicationStreamMessage), global]() - MatchingRawClientCreated = newKey[func(primitives.ServiceName, matchingservice.MatchingServiceClient), global]() 
HistoryChasmComponentsCreated = newKey[func(HistoryChasmComponents), global]() ) diff --git a/tests/testcore/clients.go b/tests/testcore/clients.go index 923b20e0526..822009eba21 100644 --- a/tests/testcore/clients.go +++ b/tests/testcore/clients.go @@ -5,19 +5,26 @@ import ( "fmt" "sync" + "github.com/dgryski/go-farm" "go.temporal.io/api/operatorservice/v1" "go.temporal.io/api/workflowservice/v1" "go.temporal.io/server/api/adminservice/v1" "go.temporal.io/server/api/historyservice/v1" "go.temporal.io/server/api/matchingservice/v1" schedulerpb "go.temporal.io/server/chasm/lib/scheduler/gen/schedulerpb/v1" + "go.temporal.io/server/client/matching" + "go.temporal.io/server/common" + "go.temporal.io/server/common/dynamicconfig" "go.temporal.io/server/common/log" "go.temporal.io/server/common/log/tag" + "go.temporal.io/server/common/membership" "go.temporal.io/server/common/membership/static" "go.temporal.io/server/common/metrics" + "go.temporal.io/server/common/namespace" "go.temporal.io/server/common/primitives" "go.temporal.io/server/common/rpc" "go.temporal.io/server/common/rpc/encryption" + "go.temporal.io/server/common/testing/testhooks" "google.golang.org/grpc" ) @@ -25,6 +32,9 @@ type clients struct { logger log.Logger hostsByService map[primitives.ServiceName]static.Hosts tlsConfigProvider *encryption.FixedTLSConfigProvider + dc *dynamicconfig.Collection + testHooks testhooks.TestHooks + metricsHandler metrics.Handler frontend frontendClients history historyClients @@ -47,18 +57,28 @@ type historyClients struct { } type matchingClient struct { + once sync.Once client matchingservice.MatchingServiceClient + + clientConnsLock sync.Mutex + clientConns []*grpc.ClientConn } func newClients( logger log.Logger, hostsByService map[primitives.ServiceName]static.Hosts, tlsConfigProvider *encryption.FixedTLSConfigProvider, + dc *dynamicconfig.Collection, + testHooks testhooks.TestHooks, + metricsHandler metrics.Handler, ) clients { return clients{ logger: logger, 
hostsByService: hostsByService, tlsConfigProvider: tlsConfigProvider, + dc: dc, + testHooks: testHooks, + metricsHandler: metricsHandler, } } @@ -113,9 +133,13 @@ func (c *clients) ensureHistory() { } func (c *clients) MatchingClient() matchingservice.MatchingServiceClient { - if c.matching.client == nil { - c.logger.Fatal("matching test client has not been initialized") - } + c.matching.once.Do(func() { + client, err := c.newMatchingClient() + if err != nil { + c.logger.Fatal("unable to create matching test client", tag.Error(err)) + } + c.matching.client = client + }) return c.matching.client } @@ -129,6 +153,12 @@ func (c *clients) close() []error { errs = append(errs, conn.Close()) } } + c.matching.clientConnsLock.Lock() + for _, conn := range c.matching.clientConns { + errs = append(errs, conn.Close()) + } + c.matching.clientConns = nil + c.matching.clientConnsLock.Unlock() c.frontend.conn = nil c.history.conn = nil return errs @@ -168,3 +198,144 @@ func (c *clients) tlsConfig(serviceName primitives.ServiceName) (*tls.Config, er } return c.tlsConfigProvider.GetInternodeClientConfig() } + +func (c *clients) newMatchingClient() (matchingservice.MatchingServiceClient, error) { + resolver := newTestMatchingServiceResolver(c.hostsByService[primitives.MatchingService].All) + if resolver.MemberCount() == 0 { + return nil, fmt.Errorf("no matching gRPC hosts configured") + } + + clientProvider := func(clientKey string) (any, func() error, error) { + conn, err := c.newConnToAddress(primitives.MatchingService, clientKey) + if err != nil { + return nil, nil, err + } + c.matching.clientConnsLock.Lock() + c.matching.clientConns = append(c.matching.clientConns, conn) + c.matching.clientConnsLock.Unlock() + return matchingservice.NewMatchingServiceClient(conn), conn.Close, nil + } + client := matching.NewClient( + matching.DefaultTimeout, + matching.DefaultLongPollTimeout, + common.NewClientCache(&testMatchingClientKeyResolver{resolver: resolver}, clientProvider, c.logger), + 
c.metricsHandler, + c.logger, + matching.NewLoadBalancer(c.namespaceIDToName, c.dc, c.testHooks), + dynamicconfig.MatchingSpreadRoutingBatchSize.Get(c.dc), + resolver, + dynamicconfig.MatchingConnectionCloseDelay.Get(c.dc), + ) + if c.metricsHandler != nil { + client = matching.NewMetricClient(client, c.metricsHandler, c.logger, c.logger) + } + return client, nil +} + +func (c *clients) namespaceIDToName(id namespace.ID) (namespace.Name, error) { + resp, err := c.FrontendClient().DescribeNamespace(NewContext(), &workflowservice.DescribeNamespaceRequest{ + Id: id.String(), + }) + if err != nil { + return "", err + } + return namespace.Name(resp.GetNamespaceInfo().GetName()), nil +} + +type testMatchingClientKeyResolver struct { + resolver *testMatchingServiceResolver +} + +func (r *testMatchingClientKeyResolver) Lookup(key string, index int) (string, error) { + hosts := r.resolver.LookupN(key, index+1) + if len(hosts) == 0 { + return "", membership.ErrInsufficientHosts + } + if index >= len(hosts) { + index %= len(hosts) + } + return hosts[index].GetAddress(), nil +} + +func (r *testMatchingClientKeyResolver) GetAllAddresses() ([]string, error) { + var addresses []string + for _, host := range r.resolver.Members() { + addresses = append(addresses, host.GetAddress()) + } + return addresses, nil +} + +type testMatchingServiceResolver struct { + mu sync.Mutex + hostInfos []membership.HostInfo + listeners map[string]chan<- *membership.ChangedEvent +} + +func newTestMatchingServiceResolver(hosts []string) *testMatchingServiceResolver { + hostInfos := make([]membership.HostInfo, 0, len(hosts)) + for _, host := range hosts { + hostInfos = append(hostInfos, membership.NewHostInfoFromAddress(host)) + } + return &testMatchingServiceResolver{ + hostInfos: hostInfos, + listeners: make(map[string]chan<- *membership.ChangedEvent), + } +} + +func (r *testMatchingServiceResolver) Lookup(key string) (membership.HostInfo, error) { + r.mu.Lock() + defer r.mu.Unlock() + if 
len(r.hostInfos) == 0 { + return nil, membership.ErrInsufficientHosts + } + hash := int(farm.Fingerprint32([]byte(key))) + return r.hostInfos[hash%len(r.hostInfos)], nil +} + +func (r *testMatchingServiceResolver) LookupN(key string, _ int) []membership.HostInfo { + host, err := r.Lookup(key) + if err != nil { + return nil + } + return []membership.HostInfo{host} +} + +func (r *testMatchingServiceResolver) AddListener(name string, notifyChannel chan<- *membership.ChangedEvent) error { + r.mu.Lock() + defer r.mu.Unlock() + if _, ok := r.listeners[name]; ok { + return membership.ErrListenerAlreadyExist + } + r.listeners[name] = notifyChannel + return nil +} + +func (r *testMatchingServiceResolver) RemoveListener(name string) error { + r.mu.Lock() + defer r.mu.Unlock() + delete(r.listeners, name) + return nil +} + +func (r *testMatchingServiceResolver) MemberCount() int { + r.mu.Lock() + defer r.mu.Unlock() + return len(r.hostInfos) +} + +func (r *testMatchingServiceResolver) AvailableMemberCount() int { + return r.MemberCount() +} + +func (r *testMatchingServiceResolver) Members() []membership.HostInfo { + r.mu.Lock() + defer r.mu.Unlock() + return append([]membership.HostInfo(nil), r.hostInfos...) 
+} + +func (r *testMatchingServiceResolver) AvailableMembers() []membership.HostInfo { + return r.Members() +} + +func (r *testMatchingServiceResolver) RequestRefresh() { +} diff --git a/tests/testcore/onebox.go b/tests/testcore/onebox.go index 4c2ce857a7d..9dea0864bbb 100644 --- a/tests/testcore/onebox.go +++ b/tests/testcore/onebox.go @@ -139,6 +139,9 @@ func newTemporal(t *testing.T, params *TemporalParams) *TemporalImpl { impl.logger, impl.hostsByProtocolByService[grpcProtocol], impl.tlsConfigProvider, + dynamicconfig.NewCollection(impl.dcClient, impl.logger), + impl.testHooks, + impl.GetMetricsHandler(), ) _ = testhooks.Set( impl.testHooks, @@ -320,17 +323,6 @@ func (c *TemporalImpl) installHostTestHooks( cleanups = append(cleanups, cleanup) } - addCleanup(testhooks.Set( - c.testHooks, - testhooks.MatchingRawClientCreated, - func(name primitives.ServiceName, client matchingservice.MatchingServiceClient) { - if name == primitives.FrontendService && c.clients.matching.client == nil { - c.clients.matching.client = client - } - }, - testhooks.GlobalScope, - )) - switch serviceName { case primitives.HistoryService: addCleanup(testhooks.Set( From 7cd76eec37b2b690b2326c958199e2da6caca3e6 Mon Sep 17 00:00:00 2001 From: Stephan Behnke Date: Sat, 20 Jun 2026 20:31:02 -0700 Subject: [PATCH 07/16] Restore static host validation scope --- temporal/fx.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/temporal/fx.go b/temporal/fx.go index 3c1d420dbd1..a4f240b06b7 100644 --- a/temporal/fx.go +++ b/temporal/fx.go @@ -269,10 +269,10 @@ func ServerOptionsProvider(opts []ServerOption) (serverOptionsProvider, error) { } } - // check that when static hosts are defined, they are defined for all requested hosts + // check that when static hosts are defined, they are defined for all required hosts if len(so.hostsByService) > 0 { - for service := range so.serviceNames { - hosts := so.hostsByService[service] + for _, service := range DefaultServices { + hosts := 
so.hostsByService[primitives.ServiceName(service)] if len(hosts.All) == 0 { return serverOptionsProvider{}, fmt.Errorf("%w: %v", missingServiceInStaticHosts, service) } From 992bfdb0bf1ec09d7b50d196b4b7b296216b5090 Mon Sep 17 00:00:00 2001 From: Stephan Behnke Date: Sat, 20 Jun 2026 20:43:36 -0700 Subject: [PATCH 08/16] Fix lint issues from onebox hooks --- common/persistence/client/fx.go | 7 ++++++- common/resource/fx.go | 4 ++-- tests/testcore/clients.go | 3 ++- tests/testcore/namespace.go | 2 +- tests/xdc/base.go | 10 +++++----- 5 files changed, 16 insertions(+), 10 deletions(-) diff --git a/common/persistence/client/fx.go b/common/persistence/client/fx.go index a6a9e228d29..065f543b4f7 100644 --- a/common/persistence/client/fx.go +++ b/common/persistence/client/fx.go @@ -240,7 +240,12 @@ func managerProvider[T persistence.Closeable](newManagerFn func(Factory) (T, err } if executionManager, ok := any(manager).(persistence.ExecutionManager); ok { if hook, ok := testhooks.Get(params.TestHooks, testhooks.HistoryTasksWrittenObserver, testhooks.GlobalScope); ok { - manager = any(newHistoryTasksWrittenObserver(executionManager, hook)).(T) + wrapped, ok := any(newHistoryTasksWrittenObserver(executionManager, hook)).(T) + if !ok { + var nilT T + return nilT, errors.New("history tasks written observer produced unexpected execution manager type") + } + manager = wrapped } } params.Lifecycle.Append(fx.StopHook(manager.Close)) diff --git a/common/resource/fx.go b/common/resource/fx.go index 03ec10713b0..028b90dcc74 100644 --- a/common/resource/fx.go +++ b/common/resource/fx.go @@ -331,11 +331,11 @@ func MatchingRawClientProvider( clientBean client.Bean, namespaceRegistry namespace.Registry, ) (MatchingRawClient, error) { - client, err := clientBean.GetMatchingClient(namespaceRegistry.GetNamespaceName) + matchingClient, err := clientBean.GetMatchingClient(namespaceRegistry.GetNamespaceName) if err != nil { return nil, err } - return client, nil + return matchingClient, nil } 
func MatchingClientProvider(matchingRawClient MatchingRawClient) MatchingClient { diff --git a/tests/testcore/clients.go b/tests/testcore/clients.go index 822009eba21..b54d3eeae7e 100644 --- a/tests/testcore/clients.go +++ b/tests/testcore/clients.go @@ -2,6 +2,7 @@ package testcore import ( "crypto/tls" + "errors" "fmt" "sync" @@ -202,7 +203,7 @@ func (c *clients) tlsConfig(serviceName primitives.ServiceName) (*tls.Config, er func (c *clients) newMatchingClient() (matchingservice.MatchingServiceClient, error) { resolver := newTestMatchingServiceResolver(c.hostsByService[primitives.MatchingService].All) if resolver.MemberCount() == 0 { - return nil, fmt.Errorf("no matching gRPC hosts configured") + return nil, errors.New("no matching gRPC hosts configured") } clientProvider := func(clientKey string) (any, func() error, error) { diff --git a/tests/testcore/namespace.go b/tests/testcore/namespace.go index fa67c97a3d7..5a1d8db2251 100644 --- a/tests/testcore/namespace.go +++ b/tests/testcore/namespace.go @@ -48,7 +48,7 @@ func (tc *TestCluster) checkNamespaceAvailable( ) error { hosts := tc.host.hostsByProtocolByService[grpcProtocol][primitives.FrontendService].All if len(hosts) == 0 { - return fmt.Errorf("no frontend gRPC hosts configured") + return errors.New("no frontend gRPC hosts configured") } var errs []error diff --git a/tests/xdc/base.go b/tests/xdc/base.go index 4f445c4a64d..ef3eb490a94 100644 --- a/tests/xdc/base.go +++ b/tests/xdc/base.go @@ -382,7 +382,7 @@ func (s *xdcBaseSuite) promoteNamespace( s.waitForNamespaceAvailable(s.clusters[inClusterIndex], ns, namespaceCacheWaitTime, func(resp *workflowservice.DescribeNamespaceResponse) error { if !resp.GetIsGlobalNamespace() { - return fmt.Errorf("namespace is not global") + return errors.New("namespace is not global") } return nil }) @@ -424,12 +424,12 @@ func (s *xdcBaseSuite) failover( } func (s *xdcBaseSuite) waitForNamespaceAvailable( - cluster *testcore.TestCluster, + testCluster *testcore.TestCluster, 
ns string, waitTime time.Duration, check testcore.NamespaceAvailabilityCheck, ) { - s.Require().NoError(cluster.WaitForNamespaceAvailable( + s.Require().NoError(testCluster.WaitForNamespaceAvailable( testcore.NewContext(), ns, waitTime, @@ -440,8 +440,8 @@ func (s *xdcBaseSuite) waitForNamespaceAvailable( func compareNamespaceClusters(resp *workflowservice.DescribeNamespaceResponse, want []string) error { got := make([]string, 0, len(resp.GetReplicationConfig().GetClusters())) - for _, cluster := range resp.GetReplicationConfig().GetClusters() { - got = append(got, cluster.GetClusterName()) + for _, namespaceCluster := range resp.GetReplicationConfig().GetClusters() { + got = append(got, namespaceCluster.GetClusterName()) } if !slices.Equal(got, want) { return fmt.Errorf("namespace clusters = %v, want %v", got, want) From 3cff8b0dc2fe9013719612790889255ecc5e58fe Mon Sep 17 00:00:00 2001 From: Stephan Behnke Date: Sat, 20 Jun 2026 20:51:56 -0700 Subject: [PATCH 09/16] Allow static hosts for requested services --- temporal/fx.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/temporal/fx.go b/temporal/fx.go index a4f240b06b7..3c1d420dbd1 100644 --- a/temporal/fx.go +++ b/temporal/fx.go @@ -269,10 +269,10 @@ func ServerOptionsProvider(opts []ServerOption) (serverOptionsProvider, error) { } } - // check that when static hosts are defined, they are defined for all required hosts + // check that when static hosts are defined, they are defined for all requested hosts if len(so.hostsByService) > 0 { - for _, service := range DefaultServices { - hosts := so.hostsByService[primitives.ServiceName(service)] + for service := range so.serviceNames { + hosts := so.hostsByService[service] if len(hosts.All) == 0 { return serverOptionsProvider{}, fmt.Errorf("%w: %v", missingServiceInStaticHosts, service) } From 6927b4a6b5f6cb76fd1f2184107a37bab0d37c13 Mon Sep 17 00:00:00 2001 From: Stephan Behnke Date: Sat, 20 Jun 2026 22:22:27 -0700 Subject: [PATCH 10/16] 
Fix onebox service startup config --- temporal/fx.go | 8 +++++--- temporal/server_option.go | 9 +++++++++ temporal/server_options.go | 1 + tests/testcore/onebox.go | 41 +++++++++++++++++++++++++++++++------- 4 files changed, 49 insertions(+), 10 deletions(-) diff --git a/temporal/fx.go b/temporal/fx.go index 3c1d420dbd1..fb4a3066ed9 100644 --- a/temporal/fx.go +++ b/temporal/fx.go @@ -185,9 +185,11 @@ func ServerOptionsProvider(opts []ServerOption) (serverOptionsProvider, error) { } persistenceConfig := so.config.Persistence - err = verifyPersistenceCompatibleVersion(persistenceConfig, so.persistenceServiceResolver, logger) - if err != nil { - return serverOptionsProvider{}, err + if !so.disablePersistenceVersionCheck { + err = verifyPersistenceCompatibleVersion(persistenceConfig, so.persistenceServiceResolver, logger) + if err != nil { + return serverOptionsProvider{}, err + } } stopChan := make(chan any) diff --git a/temporal/server_option.go b/temporal/server_option.go index 03f2f9fd6fd..38e5f10a5f3 100644 --- a/temporal/server_option.go +++ b/temporal/server_option.go @@ -222,6 +222,15 @@ func WithChasmLibraries(libraries ...chasm.Library) ServerOption { }) } +// WithPersistenceVersionCheckDisabled disables startup-time persistence schema +// compatibility checks. This should only be used by callers that own schema +// setup externally and perform equivalent validation before server startup. 
+func WithPersistenceVersionCheckDisabled() ServerOption { + return applyFunc(func(s *serverOptions) { + s.disablePersistenceVersionCheck = true + }) +} + // WithCustomerMetricsProvider sets a custom implementation of the metrics.MetricsHandler interface // metrics.MetricsHandler is the base interface for publishing metric events func WithCustomMetricsHandler(provider metrics.Handler) ServerOption { diff --git a/temporal/server_options.go b/temporal/server_options.go index 7d8687c46cb..ad38a472f73 100644 --- a/temporal/server_options.go +++ b/temporal/server_options.go @@ -64,6 +64,7 @@ type ( tokenProvider auth.TokenProvider testHooks testhooks.TestHooks chasmLibraries []chasm.Library + disablePersistenceVersionCheck bool } ) diff --git a/tests/testcore/onebox.go b/tests/testcore/onebox.go index 9dea0864bbb..87f07819591 100644 --- a/tests/testcore/onebox.go +++ b/tests/testcore/onebox.go @@ -216,6 +216,20 @@ func (c *TemporalImpl) FrontendHTTPAddress() string { return addrs[rand.Intn(len(addrs))] } +func (c *TemporalImpl) frontendHTTPAddressForHost(serviceName primitives.ServiceName, host string) string { + httpAddrs := c.hostsByProtocolByService[httpProtocol][primitives.FrontendService].All + if serviceName != primitives.FrontendService { + return httpAddrs[0] + } + grpcAddrs := c.hostsByProtocolByService[grpcProtocol][primitives.FrontendService].All + for i, addr := range grpcAddrs { + if addr == host && i < len(httpAddrs) { + return httpAddrs[i] + } + } + return httpAddrs[0] +} + func (c *TemporalImpl) FrontendGRPCAddress() string { return c.hostsByProtocolByService[grpcProtocol][primitives.FrontendService].All[0] } @@ -304,6 +318,7 @@ func (c *TemporalImpl) serverOptionsForHost( temporal.WithPersistenceServiceResolver(resolver.NewNoopResolver()), temporal.WithCustomMetricsHandler(c.GetMetricsHandler()), temporal.WithChasmLibraries(chasmtests.Library), + temporal.WithPersistenceVersionCheckDisabled(), } if c.tlsConfigProvider != nil { options = append(options, 
temporal.WithTLSConfigFactory(c.tlsConfigProvider)) @@ -357,23 +372,35 @@ func (c *TemporalImpl) configForHost(serviceName primitives.ServiceName, host st BindOnIP: bindIP, GRPCPort: int(port), } + frontendGRPCAddress := c.hostsByProtocolByService[grpcProtocol][primitives.FrontendService].All[0] if serviceName == primitives.FrontendService { - // Set HTTP port and a test HTTP forwarded header - _, httpPort := mustSplitHostPort(c.FrontendHTTPAddress()) - rpcConfig.HTTPPort = int(httpPort) - rpcConfig.HTTPAdditionalForwardedHeaders = []string{ + frontendGRPCAddress = host + } + frontendBindIP, frontendGRPCPort := mustSplitHostPort(frontendGRPCAddress) + _, frontendHTTPPort := mustSplitHostPort(c.frontendHTTPAddressForHost(serviceName, host)) + // Set HTTP port and a test HTTP forwarded header + frontendRPCConfig := config.RPC{ + BindOnIP: frontendBindIP, + GRPCPort: int(frontendGRPCPort), + HTTPPort: int(frontendHTTPPort), + HTTPAdditionalForwardedHeaders: []string{ "this-header-forwarded", "this-header-prefix-forwarded-*", - } + }, } cfg := *c.config cfg.Persistence = copyPersistenceConfig(c.config.Persistence) cfg.Services = map[string]config.Service{ - string(serviceName): { - RPC: rpcConfig, + string(primitives.FrontendService): { + RPC: frontendRPCConfig, }, } + if serviceName != primitives.FrontendService { + cfg.Services[string(serviceName)] = config.Service{ + RPC: rpcConfig, + } + } return &cfg } From e129656c55cb6cba91271fcad153710c32e9d731 Mon Sep 17 00:00:00 2001 From: Stephan Behnke Date: Sat, 20 Jun 2026 22:26:37 -0700 Subject: [PATCH 11/16] Skip schema check for custom persistence factories --- temporal/fx.go | 4 +++- temporal/server_option.go | 9 --------- temporal/server_options.go | 1 - tests/testcore/onebox.go | 1 - 4 files changed, 3 insertions(+), 12 deletions(-) diff --git a/temporal/fx.go b/temporal/fx.go index fb4a3066ed9..5ab967f8710 100644 --- a/temporal/fx.go +++ b/temporal/fx.go @@ -185,7 +185,9 @@ func ServerOptionsProvider(opts 
[]ServerOption) (serverOptionsProvider, error) { } persistenceConfig := so.config.Persistence - if !so.disablePersistenceVersionCheck { + // Custom persistence factories own their own schema/version compatibility + // contract; the built-in SQL/Cassandra checker only applies to built-in stores. + if so.customDataStoreFactory == nil { err = verifyPersistenceCompatibleVersion(persistenceConfig, so.persistenceServiceResolver, logger) if err != nil { return serverOptionsProvider{}, err diff --git a/temporal/server_option.go b/temporal/server_option.go index 38e5f10a5f3..03f2f9fd6fd 100644 --- a/temporal/server_option.go +++ b/temporal/server_option.go @@ -222,15 +222,6 @@ func WithChasmLibraries(libraries ...chasm.Library) ServerOption { }) } -// WithPersistenceVersionCheckDisabled disables startup-time persistence schema -// compatibility checks. This should only be used by callers that own schema -// setup externally and perform equivalent validation before server startup. -func WithPersistenceVersionCheckDisabled() ServerOption { - return applyFunc(func(s *serverOptions) { - s.disablePersistenceVersionCheck = true - }) -} - // WithCustomerMetricsProvider sets a custom implementation of the metrics.MetricsHandler interface // metrics.MetricsHandler is the base interface for publishing metric events func WithCustomMetricsHandler(provider metrics.Handler) ServerOption { diff --git a/temporal/server_options.go b/temporal/server_options.go index ad38a472f73..7d8687c46cb 100644 --- a/temporal/server_options.go +++ b/temporal/server_options.go @@ -64,7 +64,6 @@ type ( tokenProvider auth.TokenProvider testHooks testhooks.TestHooks chasmLibraries []chasm.Library - disablePersistenceVersionCheck bool } ) diff --git a/tests/testcore/onebox.go b/tests/testcore/onebox.go index 87f07819591..93b0bf70df8 100644 --- a/tests/testcore/onebox.go +++ b/tests/testcore/onebox.go @@ -318,7 +318,6 @@ func (c *TemporalImpl) serverOptionsForHost( 
temporal.WithPersistenceServiceResolver(resolver.NewNoopResolver()), temporal.WithCustomMetricsHandler(c.GetMetricsHandler()), temporal.WithChasmLibraries(chasmtests.Library), - temporal.WithPersistenceVersionCheckDisabled(), } if c.tlsConfigProvider != nil { options = append(options, temporal.WithTLSConfigFactory(c.tlsConfigProvider)) From 621cb0f9e4fa51fc430c01ba56c84daca83d58bf Mon Sep 17 00:00:00 2001 From: Stephan Behnke Date: Sat, 20 Jun 2026 22:30:00 -0700 Subject: [PATCH 12/16] Update onebox plan ratings --- plan.md | 277 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 277 insertions(+) create mode 100644 plan.md diff --git a/plan.md b/plan.md new file mode 100644 index 00000000000..075b0d3b75d --- /dev/null +++ b/plan.md @@ -0,0 +1,277 @@ +# Onebox Fx Graph Removal Plan + +## Current state + +PR: https://github.com/temporalio/temporal/pull/10319 + +The branch has reached the mechanical goal: `tests/testcore/onebox.go` no longer builds its own per-service Fx graphs directly and instead starts services through `temporal.NewServerFx`. The remaining work is review and CI validation. + +The public `temporal/` package must stay intentionally small. Keep public additions only when there is a clear non-test/product use case. Prefer existing config/dynamic config first, focused testhooks second, and test-side workarounds third. + +## Checklist status + +- Done: DLQ Fx wrapper removal is split out and merged as #10541. +- Done: onebox client extraction is split out and merged as #10575. +- Done in this branch: `FrontendServiceRefsCreated` removed. Frontend/admin/operator/history/scheduler clients are owned by `tests/testcore/clients.go`; matching now builds a production-style matching client directly in testcore. +- Done in this branch: CHASM refs hooks are typed and focused. `HistoryChasmComponentsCreated` exposes only the engine, visibility manager, and registry. 
+- Done in this branch: `NamespaceRegistryCreated`, `ChasmRegistryInitializer`, `MatchingRawClientCreated`, broad replication gRPC hooks, and `PersistenceExecutionManagerWrapper` were removed. +- Done in this branch: namespace availability waits, replication stream observation, history-task write observation, CHASM library registration, and matching client construction now use narrower APIs. + +## Recommended PR presentation + +If this all lands in one PR, do not present it primarily as a list of removed hooks. The cleaner review partition is by design boundary: + +### 1. Start onebox through the production server graph + +What to present: + +- `tests/testcore/onebox.go` starts every service host with `temporal.NewServerFx`. +- Onebox still passes only the requested service to `ForServices`. +- Static host validation now requires `Self` only for services requested by `ForServices`. +- Each per-service config still includes frontend HTTP connection details so local system callbacks and HTTP clients do not fall back to port `0`. +- Built-in SQL/Cassandra schema validation is skipped when a custom persistence factory is supplied, because custom factories own their own schema/version compatibility contract. + +Cleanliness rating: 2. This is the core architectural win, but the custom-factory schema-check rule and frontend config duplication deserve focused review. + +### 2. Replace graph-captured service refs with test-owned clients + +What to present: + +- Frontend/admin/operator/history/scheduler clients are constructed directly from known onebox addresses. +- Matching uses `matching.NewClient` with a test static resolver, test dynamic config, metrics, TLS-aware dialing, and frontend-backed namespace ID lookup. + +Cleanliness rating: 2. This removes fx graph access, but the matching client path intentionally mirrors production construction and may need maintenance if matching client dependencies change. + +### 3. 
Replace broad customization hooks with focused observation hooks + +What to present: + +- Replication recording uses `ReplicationStreamMessageObserver` instead of generic gRPC interceptor/dial-option hooks. +- Task queue recording uses `HistoryTasksWrittenObserver` instead of replacing `persistence.ExecutionManager`. +- CHASM component capture remains focused on the few values tests inspect. + +Cleanliness rating: 3. These hooks expose events or typed refs rather than letting tests replace arbitrary graph pieces. + +### 4. Move test-specific setup to owning suites or real extension points + +What to present: + +- Archival setup is owned by archival tests/config. +- DLQ setup owns its queue manager and SDK client directly. +- CHASM test library registration uses `temporal.WithChasmLibraries`, a real embedded-server extension point. +- Custom archiver factory options remain public only because they are embedded-server configuration, not onebox-only graph access. + +Cleanliness rating: 3. Ownership follows the suite or product extension point that actually needs the behavior. + +### 5. Use observable readiness instead of internal cache refs + +What to present: + +- Namespace registry refs are gone. +- Namespace creation waits through public frontend APIs. +- XDC failover tests explicitly wait for cache refresh only where stale cache state is part of the behavior. + +Cleanliness rating: 3. The test waits on user-visible behavior instead of registry internals. + +### 6. Public API review checklist + +What to present: + +- Keep: `WithTestHooks`, because focused internal testhooks still need a contained injection point. +- Keep: `WithChasmLibraries`, because embedded servers need a pre-start CHASM library registration point. +- Keep: custom archiver factory options, if reviewers agree they are valid embedded-server configuration. +- Avoid: public gRPC interceptor/dial-option options for onebox-only recording. 
+- Avoid: onebox-only public switches such as an explicit persistence version-check disable option. + +Cleanliness rating: 2. The public surface is small, but every exported option should be defended independently from onebox. + +## Historical audit trail + +### 1. `FrontendServiceRefsCreated` + +Resolution: removed from `service/frontend/fx.go` and `common/testing/testhooks/hooks.go`. + +What changed: frontend/admin/operator clients are now created lazily by `tests/testcore/clients.go` using direct frontend gRPC dialing. History and scheduler clients use the same direct test-client path. Matching now builds a production-style matching client directly in testcore without an fx graph capture hook. + +Options considered: + +- Keep `FrontendServiceRefsCreated`: rejected because it exposed unrelated frontend graph refs. +- Create all test clients directly from gRPC addresses: works for frontend/history/scheduler, but failed for matching. +- Build a production matching client from testcore without fx graph access: selected after testcore gained the static resolver and namespace lookup needed to preserve production routing behavior. + +Obstacle discovered: direct matching gRPC clients failed `TestNexusMatchingTestSuite/TestDispatchNexusTaskOnNonRootPartitionNoForwarding`; they bypass the production matching routing behavior expected by Nexus matching tests. + +Public `temporal/` impact: none. + +### 2. CHASM service refs + +Resolution: replaced broad history refs with typed CHASM-specific hooks. + +What changed: `HistoryServiceRefsCreated` became `HistoryChasmComponentsCreated`, which exposes only `chasm.Engine`, `chasm.VisibilityManager`, and `*chasm.Registry`. `ChasmRegistryInitializer` was removed after CHASM library registration moved to `temporal.WithChasmLibraries`. + +Options considered: + +- Keep `HistoryServiceRefsCreated`: rejected because the name and shape imply arbitrary service graph access. 
+- Return the three CHASM values as untyped values: rejected because it preserved the same cast-heavy pattern. +- Use a typed struct: selected because it is explicit, compile-time checked, and still narrow. +- Avoid a hook entirely by moving CHASM test-library registration fully into test code: selected via the server option path. + +Obstacle discovered: CHASM registration needs to happen while the production server graph is being initialized. + +Public `temporal/` impact: `WithChasmLibraries`. + +### 3. `PersistenceExecutionManagerWrapper` + +Resolution: replaced the execution-manager wrapper with a focused task-write observer. + +What changed: `TaskQueueRecorder` now observes `HistoryTasksWritten` events emitted after successful persistence writes instead of wrapping `persistence.ExecutionManager`. + +Options considered: + +- Leave the arbitrary wrapper hook: rejected as too broad. +- Replace with `HistoryTasksWritten`: selected because it observes the exact event tests need. +- Make task recording opt-in: preferred and tracked by #10583. + +Obstacle discovered: recording must happen only after successful writes and must cover all task-emitting write paths. + +Public `temporal/` impact: internal testhook only. + +### 4. DLQ Fx wrapper removal + +Resolution: generic `WithFxOptions` usage was removed from `tests/dlq_test.go`. + +What changed: the DLQ suite now creates its own `HistoryTaskQueueManager` from the persistence test factory and dials its own system SDK client when it needs to wait on DLQ job workflows. + +Options considered: + +- Keep `fx.Populate` via `WithFxOptions`: rejected because it preserves generic test graph access. +- Add a DLQ-specific hook: unnecessary because the suite can get the persistence queue manager from existing test cluster state. +- Reuse the production SDK client factory: rejected for this branch because that would require another graph access path; direct SDK dialing is enough for the test. 
+ +Obstacle discovered: `parallelsuite` rejects `s.Require()`, so setup assertions must use the suite’s direct assertion methods. + +Public `temporal/` impact: none. + +## TODO solution evaluation + +Scale: 1 = hacky / fragile, 2 = acceptable but with visible compromise, 3 = clean / narrow / maintainable. + +### 1. `NamespaceRegistryCreated` + +Implemented solution: removed `testhooks.NamespaceRegistryCreated`. Onebox now waits through public namespace APIs with `TestCluster.WaitForNamespaceAvailable`, and xdc failover paths wait for namespace cache refresh where stale cache state matters. + +Cleanliness rating: 3. This uses the public frontend surface and keeps namespace registry internals out of testcore. + +Options considered: + +- Keep collecting registries directly: rejected because it is graph refs access. +- Use `NamespaceCacheRefreshInterval` with a very low interval: possible but indirect and can make tests timing-sensitive. +- Add/use force-refresh-on-read dynamic config: explored, but new namespace reads already fall through to persistence; this is more useful for update/delete/replication staleness than ordinary creation. +- Wait for frontend namespace visibility only: likely enough for new namespace creation because services can read through to persistence on misses. +- New idea: the registry already does persistence read-through on cache miss (with a 1s not-found TTL); after `RegisterNamespace` returns, the namespace exists in persistence and all services will find it on the next miss (at most one TTL interval later). A focused test utility `WaitForNamespaceAvailable` that calls `DescribeNamespace` on each service's frontend gRPC address with short-interval retry would be zero-production-change and make the wait condition explicit. No registry refs, no hook, no config tweak. + +Obstacle discovered: namespace cache behavior is subtle. 
A force-refresh option is not clearly justified for new namespace creation, and using it broadly would be a product config addition for a test-only concern. + +Public `temporal/` impact: none. + +### 2. Archival setup + +Implemented solution: archival setup stays suite-owned instead of onebox-owned. The branch already had the archival config/factory plumbing needed by `tests/archival_test.go`, so no additional onebox change was required. + +Cleanliness rating: 3. The behavior is owned by the suite that needs it and does not add a onebox-specific hook. + +Options considered: + +- Keep archival setup in onebox: rejected because only archival tests need it. +- Move archival setup into the archival suite: preferred; the suite controls the config it needs. +- Add public archiver factory options to `temporal/`: possible, but should only be kept if independently useful outside onebox tests. +- Use testhooks for archiver factories: possible but probably worse than suite-owned config because archiver setup is test-specific configuration, not an execution hook. + +Public `temporal/` impact: `temporal.WithCustomHistoryArchiverFactory` and `temporal.WithCustomVisibilityArchiverFactory` remain justifiable as embedded-server configuration, not onebox graph access. + +### 3. Replication stream recorder hooks + +Implemented solution: removed generic service interceptor and client dial-option hooks. Added a focused `ReplicationStreamMessageObserver` testhook at the replication stream sender/receiver boundary, and made `ReplicationStreamRecorder` consume that focused message stream. + +Cleanliness rating: 3. The hook is scoped to exactly the observed subsystem and no longer allows arbitrary service gRPC customization. + +Options considered: + +- Public interceptor/dial-option server options: rejected for now; this is onebox-only instrumentation. +- Keep broad `ServiceGrpcInterceptors` and `ServiceClientDialOptions`: works, but too much customization surface. 
+- Recorder-specific testhooks: likely best if the recorder must stay server-side and client-side. +- Test-side recorder only: insufficient if the test needs to observe server-to-server replication streams. +- New idea: the recorder already filters to exactly one method (`StreamWorkflowReplicationMessages`). Instead of per-service generic interceptor/dial-option hooks, add a focused `ReplicationStreamObserver` hook on the replication manager's stream factory or task executor. The replication manager already owns the stream lifecycle; a single hook there covers all replication traffic without generic per-service gRPC hooks. This replaces two broad hooks with one narrow hook scoped to the replication subsystem. + +Public `temporal/` impact: avoid new public interceptor/dial options. Existing `WithChainedFrontendGrpcInterceptors` should not be expanded for onebox-only needs. + +### 4. `TaskQueueRecorder` final shape + +Implemented solution: removed `testhooks.PersistenceExecutionManagerWrapper`. Added focused `HistoryTasksWrittenObserver` / `HistoryTasksWritten` plumbing in persistence client write paths, and made `TaskQueueRecorder` a sink instead of a persistence manager wrapper. + +Cleanliness rating: 3. The hook observes the event tests need after successful writes and no longer replaces the persistence implementation. + +Options considered: + +- Keep the typed wrapper: acceptable as an intermediate cleanup, not ideal final shape. +- Add `HistoryTasksWritten`: preferred because it observes the event tests need without replacing persistence. +- Make recorder opt-in: preferred so most tests do not pay allocation or recording overhead. +- Move recorder fully into xdc tests: possible for ownership, but the write observation point still needs to be in production/testhook plumbing. + +Public `temporal/` impact: internal testhook only. + +### 5. `ChasmRegistryInitializer` + +Implemented solution: removed `testhooks.ChasmRegistryInitializer`. 
Added `temporal.WithChasmLibraries(...chasm.Library)` and pass `chasmtests.Library` through onebox server options. + +Cleanliness rating: 3. This is a real embedded-server extension point and registers libraries before startup without exposing registry internals. + +Options considered: + +- Keep typed `ChasmRegistryInitializer`: simple and narrow, but still startup-time test behavior in the production graph. +- Use `HistoryChasmComponentsCreated` to register after creation: likely too late if services need libraries registered during startup. +- Make CHASM test registry setup non-fx/test-owned: attractive, but needs a way to provide the preconfigured registry to the production graph without adding public options. +- New idea: `temporal.WithChasmLibraries(...chasm.Library)` as a public production option, not just a test concern. Production CHASM modules register their libraries via `fx.Invoke` inside each service's fx module, which is not accessible to users embedding Temporal via `temporal.NewServer`. Anyone writing a custom CHASM state machine for an embedded deployment would need this same registration point. Exposing it as a supported option would be independently justified, eliminate the testhook entirely, and give embedded users a proper extension point without requiring them to fork service fx modules. + +Public `temporal/` impact: adds `WithChasmLibraries`, which is broader than onebox but justified for embedded users that need to register custom CHASM libraries. + +### 6. `MatchingRawClientCreated` + +Implemented solution: removed `testhooks.MatchingRawClientCreated`. `tests/testcore/clients.go` now constructs a production-style matching client directly with `matching.NewClient`, a static test resolver over known matching hosts, namespace ID lookup through frontend, test dynamic config, metrics, and TLS-aware gRPC dialing. + +Cleanliness rating: 2. 
This removes graph access and preserves production routing behavior, but it duplicates enough matching-client construction in testcore that it should be watched if matching client dependencies change. + +Options considered: + +- Direct matching gRPC client: rejected after the Nexus matching regression. +- Recreate production matching raw client in testcore: initially judged impractical without duplicating client bean and namespace routing setup; selected once testcore gained a static resolver and frontend-backed namespace lookup. +- Keep old frontend refs hook: rejected as too broad. +- Capture only the production matching raw client: the narrowest hook-based path, but superseded because the implemented solution removes `testhooks.MatchingRawClientCreated` entirely. +- New idea: the matching raw client construction in `clientfactory.go` needs a membership resolver, RPC factory, metrics handler, dynamic config, logger, and a namespace-ID-to-name function. These are all available in testcore via other already-collected state (service addresses, membership ports, onebox logger, etc.). Testcore could construct a static membership resolver from known service addresses and call the existing `matching.NewClient` directly without graph access. This needed investigation to confirm the full dependency list is available without the production fx graph; it proved viable and removes the hook entirely rather than just narrowing it. + +Public `temporal/` impact: none; the internal testhook was removed. + +### 7. Public `temporal/` package diff + +Implemented solution: + +- Keep `WithTestHooks` for focused internal testhooks. +- Keep `WithChasmLibraries` as a non-test-specific embedded-server extension point. +- Keep custom archiver factory options as embedded-server configuration. +- Relax static-host validation to require self-addresses only for services requested by `ForServices`. +- Skip built-in persistence schema/version checks when a custom persistence factory is supplied, because custom factories own their own compatibility contract. + +Cleanliness rating: 2.
The public surface is still small and the onebox-specific skip option was avoided, but `WithTestHooks` remains intentionally test-only and the custom-factory schema-check rule deserves reviewer attention. + +Options considered: + +- Add public options for every onebox customization: rejected; this would turn test-only needs into supported product API. +- Keep only `WithTestHooks`: preferred because it gives test-only branches a contained extension point. +- Use existing config/dynamic config everywhere possible: preferred before testhooks. +- Reach into fx graphs from onebox: rejected as the pattern this PR is trying to remove. + +Obstacle discovered: onebox's custom persistence factories exercise `NewServerFx` through a path where the built-in config-backed SQL/Cassandra schema checker is the wrong owner. The implemented rule keeps production defaults strict while letting custom factories define their own schema contract. + +## Are we close? + +Yes. Onebox now starts services through `temporal.NewServerFx`, broad graph-access hooks are removed, and the remaining compromises are explicit in the ratings above. The root PR should stay draft until CI validates the final branch state. 
From 95835cc0cb5e2cd8e8747d76656cebb8727e0223 Mon Sep 17 00:00:00 2001 From: Stephan Behnke Date: Sat, 20 Jun 2026 22:45:07 -0700 Subject: [PATCH 13/16] Initialize test schema versions --- common/persistence/cassandra/test.go | 38 +++++++++++++++++++ .../persistence/sql/test_sql_persistence.go | 36 ++++++++++++++++++ 2 files changed, 74 insertions(+) diff --git a/common/persistence/cassandra/test.go b/common/persistence/cassandra/test.go index 1adfc15651f..413cb7cd54a 100644 --- a/common/persistence/cassandra/test.go +++ b/common/persistence/cassandra/test.go @@ -17,6 +17,7 @@ import ( commongocql "go.temporal.io/server/common/persistence/nosql/nosqlplugin/cassandra/gocql" "go.temporal.io/server/common/primitives" "go.temporal.io/server/common/resolver" + cassandraschema "go.temporal.io/server/schema/cassandra" "go.temporal.io/server/temporal/environment" "go.temporal.io/server/tests/testutils" ) @@ -93,6 +94,7 @@ func (s *TestCluster) SetupTestDatabase() { } s.LoadSchema(path.Join(schemaDir, "temporal", "schema.cql")) + s.LoadSchemaVersion() } // TearDownTestDatabase from PersistenceTestCluster interface @@ -181,6 +183,42 @@ func (s *TestCluster) LoadSchema(schemaFile string) { s.logger.Info("loaded schema") } +// LoadSchemaVersion writes the schema metadata expected by server startup validation. 
+func (s *TestCluster) LoadSchemaVersion() { + for _, stmt := range []string{ + `CREATE TABLE IF NOT EXISTS schema_version(keyspace_name text PRIMARY KEY, creation_time timestamp, curr_version text, min_compatible_version text);`, + `CREATE TABLE IF NOT EXISTS schema_update_history(year int, month int, update_time timestamp, description text, manifest_md5 text, new_version text, old_version text, PRIMARY KEY ((year, month), update_time));`, + } { + if err := s.session.Query(stmt).Exec(); err != nil { + s.logger.Fatal("LoadSchemaVersion", tag.Error(err)) + } + } + + now := time.Now().UTC() + if err := s.session.Query( + `INSERT into schema_version(keyspace_name, creation_time, curr_version, min_compatible_version) VALUES (?,?,?,?)`, + s.keyspace, + now, + cassandraschema.Version, + cassandraschema.Version, + ).Exec(); err != nil { + s.logger.Fatal("LoadSchemaVersion", tag.Error(err)) + } + if err := s.session.Query( + `INSERT into schema_update_history(year, month, update_time, old_version, new_version, manifest_md5, description) VALUES(?,?,?,?,?,?,?)`, + now.Year(), + int(now.Month()), + now, + "0", + cassandraschema.Version, + "", + "initial version", + ).Exec(); err != nil { + s.logger.Fatal("LoadSchemaVersion", tag.Error(err)) + } + s.logger.Info("loaded schema version", tag.String("version", cassandraschema.Version)) +} + func (s *TestCluster) GetSession() commongocql.Session { return s.session } diff --git a/common/persistence/sql/test_sql_persistence.go b/common/persistence/sql/test_sql_persistence.go index b92589e9a1a..20abae23d6e 100644 --- a/common/persistence/sql/test_sql_persistence.go +++ b/common/persistence/sql/test_sql_persistence.go @@ -83,6 +83,7 @@ func (s *TestCluster) SetupTestDatabase() { } s.LoadSchema(path.Join(schemaDir, "temporal", "schema.sql")) s.LoadSchema(path.Join(schemaDir, "visibility", "schema.sql")) + s.LoadSchemaVersion() } // Config returns the persistence config for connecting to this test cluster @@ -220,3 +221,38 @@ func (s 
*TestCluster) LoadSchema(schemaFile string) { } s.logger.Info("loaded schema") } + +// LoadSchemaVersion writes the schema metadata expected by server startup validation. +func (s *TestCluster) LoadSchemaVersion() { + var db sqlplugin.AdminDB + var err error + err = backoff.ThrottleRetry( + func() error { + db, err = NewSQLAdminDB(sqlplugin.DbKindMain, &s.cfg, resolver.NewNoopResolver(), log.NewTestLogger(), metrics.NoopMetricsHandler) + return err + }, + backoff.NewExponentialRetryPolicy(time.Second).WithExpirationInterval(time.Minute), + nil, + ) + if err != nil { + panic(err) + } + defer func() { + err := db.Close() + if err != nil { + panic(err) + } + }() + + expectedVersion := db.ExpectedVersion() + if err = db.CreateSchemaVersionTables(); err != nil { + s.logger.Fatal("CreateSchemaVersionTables", tag.Error(err)) + } + if err = db.UpdateSchemaVersion(s.cfg.DatabaseName, expectedVersion, expectedVersion); err != nil { + s.logger.Fatal("UpdateSchemaVersion", tag.Error(err)) + } + if err = db.WriteSchemaUpdateLog("0", expectedVersion, "", "initial version"); err != nil { + s.logger.Fatal("WriteSchemaUpdateLog", tag.Error(err)) + } + s.logger.Info("loaded schema version", tag.String("version", expectedVersion)) +} From 04e7987b1e025eb521cfd4349a917266dde017a0 Mon Sep 17 00:00:00 2001 From: Stephan Behnke Date: Sat, 20 Jun 2026 22:46:26 -0700 Subject: [PATCH 14/16] Keep server schema validation enabled --- temporal/fx.go | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/temporal/fx.go b/temporal/fx.go index 5ab967f8710..3c1d420dbd1 100644 --- a/temporal/fx.go +++ b/temporal/fx.go @@ -185,13 +185,9 @@ func ServerOptionsProvider(opts []ServerOption) (serverOptionsProvider, error) { } persistenceConfig := so.config.Persistence - // Custom persistence factories own their own schema/version compatibility - // contract; the built-in SQL/Cassandra checker only applies to built-in stores. 
- if so.customDataStoreFactory == nil { - err = verifyPersistenceCompatibleVersion(persistenceConfig, so.persistenceServiceResolver, logger) - if err != nil { - return serverOptionsProvider{}, err - } + err = verifyPersistenceCompatibleVersion(persistenceConfig, so.persistenceServiceResolver, logger) + if err != nil { + return serverOptionsProvider{}, err } stopChan := make(chan any) From 638dd383669ce37d7b43c14e3a58147d5827428f Mon Sep 17 00:00:00 2001 From: Stephan Behnke Date: Sun, 21 Jun 2026 00:18:04 -0700 Subject: [PATCH 15/16] Fix onebox startup CI regressions --- common/persistence/sql/test_sql_persistence.go | 4 ++-- tests/nexus_api_validation_test.go | 8 +++++++- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/common/persistence/sql/test_sql_persistence.go b/common/persistence/sql/test_sql_persistence.go index 20abae23d6e..85cb79771c9 100644 --- a/common/persistence/sql/test_sql_persistence.go +++ b/common/persistence/sql/test_sql_persistence.go @@ -123,12 +123,12 @@ func (s *TestCluster) CreateDatabase() { nil, ) if err != nil { - panic(err) + s.logger.Fatal("NewSQLAdminDB", tag.Error(err)) } defer func() { err := db.Close() if err != nil { - panic(err) + s.logger.Fatal("Close schema version DB", tag.Error(err)) } }() err = db.CreateDatabase(s.cfg.DatabaseName) diff --git a/tests/nexus_api_validation_test.go b/tests/nexus_api_validation_test.go index bf94d29fc4e..99cdc7ff2c8 100644 --- a/tests/nexus_api_validation_test.go +++ b/tests/nexus_api_validation_test.go @@ -46,7 +46,13 @@ func (s *NexusAPIValidationTestSuite) TestNexusStartOperation_WithNamespaceAndTa requests := capture.Metric("nexus_requests") s.Len(requests, 1) - s.Equal(map[string]string{"namespace": namespace, "method": "StartNexusOperation", "outcome": "namespace_not_found", "nexus_endpoint": "_unknown_"}, requests[0].Tags) + s.Equal(map[string]string{ + "method": "StartNexusOperation", + "namespace": namespace, + "nexus_endpoint": "_unknown_", + "outcome": 
"namespace_not_found", + "service_name": "frontend", + }, requests[0].Tags) s.Equal(int64(1), requests[0].Value) } From b8f8d2418cc863b0364f7eb3ddb39af5b07b176d Mon Sep 17 00:00:00 2001 From: Stephan Behnke Date: Sun, 21 Jun 2026 09:04:19 -0700 Subject: [PATCH 16/16] Avoid panic in schema version setup --- common/persistence/sql/test_sql_persistence.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/common/persistence/sql/test_sql_persistence.go b/common/persistence/sql/test_sql_persistence.go index 85cb79771c9..ecc38c141fa 100644 --- a/common/persistence/sql/test_sql_persistence.go +++ b/common/persistence/sql/test_sql_persistence.go @@ -235,12 +235,12 @@ func (s *TestCluster) LoadSchemaVersion() { nil, ) if err != nil { - panic(err) + s.logger.Fatal("NewSQLAdminDB", tag.Error(err)) } defer func() { err := db.Close() if err != nil { - panic(err) + s.logger.Fatal("Close schema version DB", tag.Error(err)) } }()