Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions common/persistence/cassandra/test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
commongocql "go.temporal.io/server/common/persistence/nosql/nosqlplugin/cassandra/gocql"
"go.temporal.io/server/common/primitives"
"go.temporal.io/server/common/resolver"
cassandraschema "go.temporal.io/server/schema/cassandra"
"go.temporal.io/server/temporal/environment"
"go.temporal.io/server/tests/testutils"
)
Expand Down Expand Up @@ -93,6 +94,7 @@ func (s *TestCluster) SetupTestDatabase() {
}

s.LoadSchema(path.Join(schemaDir, "temporal", "schema.cql"))
s.LoadSchemaVersion()
}

// TearDownTestDatabase from PersistenceTestCluster interface
Expand Down Expand Up @@ -181,6 +183,42 @@ func (s *TestCluster) LoadSchema(schemaFile string) {
s.logger.Info("loaded schema")
}

// LoadSchemaVersion writes the schema metadata expected by server startup validation.
func (s *TestCluster) LoadSchemaVersion() {
for _, stmt := range []string{
`CREATE TABLE IF NOT EXISTS schema_version(keyspace_name text PRIMARY KEY, creation_time timestamp, curr_version text, min_compatible_version text);`,
`CREATE TABLE IF NOT EXISTS schema_update_history(year int, month int, update_time timestamp, description text, manifest_md5 text, new_version text, old_version text, PRIMARY KEY ((year, month), update_time));`,
} {
if err := s.session.Query(stmt).Exec(); err != nil {
s.logger.Fatal("LoadSchemaVersion", tag.Error(err))
}
}

now := time.Now().UTC()
if err := s.session.Query(
`INSERT into schema_version(keyspace_name, creation_time, curr_version, min_compatible_version) VALUES (?,?,?,?)`,
s.keyspace,
now,
cassandraschema.Version,
cassandraschema.Version,
).Exec(); err != nil {
s.logger.Fatal("LoadSchemaVersion", tag.Error(err))
}
if err := s.session.Query(
`INSERT into schema_update_history(year, month, update_time, old_version, new_version, manifest_md5, description) VALUES(?,?,?,?,?,?,?)`,
now.Year(),
int(now.Month()),
now,
"0",
cassandraschema.Version,
"",
"initial version",
).Exec(); err != nil {
s.logger.Fatal("LoadSchemaVersion", tag.Error(err))
}
s.logger.Info("loaded schema version", tag.String("version", cassandraschema.Version))
}

func (s *TestCluster) GetSession() commongocql.Session {
return s.session
}
27 changes: 23 additions & 4 deletions common/persistence/client/fx.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"go.temporal.io/server/common/quotas"
"go.temporal.io/server/common/resolver"
otel "go.temporal.io/server/common/telemetry"
"go.temporal.io/server/common/testing/testhooks"
"go.uber.org/fx"
)

Expand Down Expand Up @@ -217,9 +218,17 @@ func DataStoreFactoryLifetimeHooks(lc fx.Lifecycle, f persistence.DataStoreFacto
lc.Append(fx.StopHook(f.Close))
}

func managerProvider[T persistence.Closeable](newManagerFn func(Factory) (T, error)) func(Factory, fx.Lifecycle) (T, error) {
return func(f Factory, lc fx.Lifecycle) (T, error) {
manager, err := newManagerFn(f) // passing receiver (Factory) as first argument.
type managerProviderParams struct {
fx.In

Factory Factory
Lifecycle fx.Lifecycle
TestHooks testhooks.TestHooks `optional:"true"`
}

func managerProvider[T persistence.Closeable](newManagerFn func(Factory) (T, error)) func(managerProviderParams) (T, error) {
return func(params managerProviderParams) (T, error) {
manager, err := newManagerFn(params.Factory) // passing receiver (Factory) as first argument.
if err != nil {
var unimpl *serviceerror.Unimplemented
if errors.As(err, &unimpl) {
Expand All @@ -229,7 +238,17 @@ func managerProvider[T persistence.Closeable](newManagerFn func(Factory) (T, err
var nilT T
return nilT, err
}
lc.Append(fx.StopHook(manager.Close))
if executionManager, ok := any(manager).(persistence.ExecutionManager); ok {
if hook, ok := testhooks.Get(params.TestHooks, testhooks.HistoryTasksWrittenObserver, testhooks.GlobalScope); ok {
wrapped, ok := any(newHistoryTasksWrittenObserver(executionManager, hook)).(T)
if !ok {
var nilT T
return nilT, errors.New("history tasks written observer produced unexpected execution manager type")
}
manager = wrapped
}
}
params.Lifecycle.Append(fx.StopHook(manager.Close))
return manager, nil
}
}
110 changes: 110 additions & 0 deletions common/persistence/client/history_tasks_observer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
package client

import (
"context"

"go.temporal.io/server/common/persistence"
"go.temporal.io/server/common/testing/testhooks"
)

type historyTasksWrittenObserver struct {
persistence.ExecutionManager
observe func(testhooks.HistoryTasksWritten)
}

func newHistoryTasksWrittenObserver(
manager persistence.ExecutionManager,
observe func(testhooks.HistoryTasksWritten),
) persistence.ExecutionManager {
return &historyTasksWrittenObserver{
ExecutionManager: manager,
observe: observe,
}
}

func (o *historyTasksWrittenObserver) AddHistoryTasks(
ctx context.Context,
request *persistence.AddHistoryTasksRequest,
) error {
err := o.ExecutionManager.AddHistoryTasks(ctx, request)
if err == nil && request != nil {
o.observe(testhooks.HistoryTasksWritten{
ShardID: request.ShardID,
RangeID: request.RangeID,
NamespaceID: request.NamespaceID,
WorkflowID: request.WorkflowID,
Tasks: request.Tasks,
})
}
return err
}

func (o *historyTasksWrittenObserver) CreateWorkflowExecution(
ctx context.Context,
request *persistence.CreateWorkflowExecutionRequest,
) (*persistence.CreateWorkflowExecutionResponse, error) {
response, err := o.ExecutionManager.CreateWorkflowExecution(ctx, request)
if err == nil && request != nil {
o.observeWorkflowSnapshot(request.ShardID, request.RangeID, &request.NewWorkflowSnapshot)
}
return response, err
}

func (o *historyTasksWrittenObserver) UpdateWorkflowExecution(
ctx context.Context,
request *persistence.UpdateWorkflowExecutionRequest,
) (*persistence.UpdateWorkflowExecutionResponse, error) {
response, err := o.ExecutionManager.UpdateWorkflowExecution(ctx, request)
if err == nil && request != nil {
o.observeWorkflowMutation(request.ShardID, request.RangeID, &request.UpdateWorkflowMutation)
o.observeWorkflowSnapshot(request.ShardID, request.RangeID, request.NewWorkflowSnapshot)
}
return response, err
}

func (o *historyTasksWrittenObserver) ConflictResolveWorkflowExecution(
ctx context.Context,
request *persistence.ConflictResolveWorkflowExecutionRequest,
) (*persistence.ConflictResolveWorkflowExecutionResponse, error) {
response, err := o.ExecutionManager.ConflictResolveWorkflowExecution(ctx, request)
if err == nil && request != nil {
o.observeWorkflowSnapshot(request.ShardID, request.RangeID, &request.ResetWorkflowSnapshot)
o.observeWorkflowSnapshot(request.ShardID, request.RangeID, request.NewWorkflowSnapshot)
o.observeWorkflowMutation(request.ShardID, request.RangeID, request.CurrentWorkflowMutation)
}
return response, err
}

func (o *historyTasksWrittenObserver) observeWorkflowSnapshot(
shardID int32,
rangeID int64,
snapshot *persistence.WorkflowSnapshot,
) {
if snapshot == nil || snapshot.ExecutionInfo == nil {
return
}
o.observe(testhooks.HistoryTasksWritten{
ShardID: shardID,
RangeID: rangeID,
NamespaceID: snapshot.ExecutionInfo.NamespaceId,
WorkflowID: snapshot.ExecutionInfo.WorkflowId,
Tasks: snapshot.Tasks,
})
}

func (o *historyTasksWrittenObserver) observeWorkflowMutation(
shardID int32,
rangeID int64,
mutation *persistence.WorkflowMutation,
) {
if mutation == nil || mutation.ExecutionInfo == nil {
return
}
o.observe(testhooks.HistoryTasksWritten{
ShardID: shardID,
RangeID: rangeID,
NamespaceID: mutation.ExecutionInfo.NamespaceId,
WorkflowID: mutation.ExecutionInfo.WorkflowId,
Tasks: mutation.Tasks,
})
}
40 changes: 38 additions & 2 deletions common/persistence/sql/test_sql_persistence.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ func (s *TestCluster) SetupTestDatabase() {
}
s.LoadSchema(path.Join(schemaDir, "temporal", "schema.sql"))
s.LoadSchema(path.Join(schemaDir, "visibility", "schema.sql"))
s.LoadSchemaVersion()
}

// Config returns the persistence config for connecting to this test cluster
Expand Down Expand Up @@ -122,12 +123,12 @@ func (s *TestCluster) CreateDatabase() {
nil,
)
if err != nil {
panic(err)
s.logger.Fatal("NewSQLAdminDB", tag.Error(err))
}
defer func() {
err := db.Close()
if err != nil {
panic(err)
s.logger.Fatal("Close schema version DB", tag.Error(err))
}
}()
err = db.CreateDatabase(s.cfg.DatabaseName)
Expand Down Expand Up @@ -220,3 +221,38 @@ func (s *TestCluster) LoadSchema(schemaFile string) {
}
s.logger.Info("loaded schema")
}

// LoadSchemaVersion writes the schema metadata expected by server startup validation.
func (s *TestCluster) LoadSchemaVersion() {
var db sqlplugin.AdminDB
var err error
err = backoff.ThrottleRetry(
func() error {
db, err = NewSQLAdminDB(sqlplugin.DbKindMain, &s.cfg, resolver.NewNoopResolver(), log.NewTestLogger(), metrics.NoopMetricsHandler)
return err
},
backoff.NewExponentialRetryPolicy(time.Second).WithExpirationInterval(time.Minute),
nil,
)
if err != nil {
s.logger.Fatal("NewSQLAdminDB", tag.Error(err))
}
defer func() {
err := db.Close()
if err != nil {
s.logger.Fatal("Close schema version DB", tag.Error(err))
}
}()

expectedVersion := db.ExpectedVersion()
if err = db.CreateSchemaVersionTables(); err != nil {
s.logger.Fatal("CreateSchemaVersionTables", tag.Error(err))
}
if err = db.UpdateSchemaVersion(s.cfg.DatabaseName, expectedVersion, expectedVersion); err != nil {
s.logger.Fatal("UpdateSchemaVersion", tag.Error(err))
}
if err = db.WriteSchemaUpdateLog("0", expectedVersion, "", "initial version"); err != nil {
s.logger.Fatal("WriteSchemaUpdateLog", tag.Error(err))
}
s.logger.Info("loaded schema version", tag.String("version", expectedVersion))
}
13 changes: 10 additions & 3 deletions common/resource/fx.go
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,7 @@ func SearchAttributeValidatorProvider(
type NamespaceRegistryParams struct {
fx.In

ServiceName primitives.ServiceName `optional:"true"`
Logger log.SnTaggedLogger
MetricsHandler metrics.Handler
ClusterMetadata cluster.Metadata
Expand All @@ -235,7 +236,7 @@ type NamespaceRegistryParams struct {
}

func NamespaceRegistryProvider(params NamespaceRegistryParams) namespace.Registry {
return nsregistry.NewRegistry(
registry := nsregistry.NewRegistry(
params.MetadataManager,
params.ClusterMetadata.IsGlobalNamespaceEnabled(),
params.ClusterMetadata.GetCurrentClusterName(),
Expand All @@ -246,6 +247,7 @@ func NamespaceRegistryProvider(params NamespaceRegistryParams) namespace.Registr
params.ReplicationResolverFactory,
params.NamespaceStateChangedFn,
)
return registry
}

func ClientFactoryProvider(
Expand Down Expand Up @@ -329,7 +331,11 @@ func MatchingRawClientProvider(
clientBean client.Bean,
namespaceRegistry namespace.Registry,
) (MatchingRawClient, error) {
return clientBean.GetMatchingClient(namespaceRegistry.GetNamespaceName)
matchingClient, err := clientBean.GetMatchingClient(namespaceRegistry.GetNamespaceName)
if err != nil {
return nil, err
}
return matchingClient, nil
}

func MatchingClientProvider(matchingRawClient MatchingRawClient) MatchingClient {
Expand Down Expand Up @@ -406,10 +412,11 @@ func PerServiceDialOptionsProvider(
) map[primitives.ServiceName][]grpc.DialOption {
trailerInterceptor := interceptor.TrailerToContextMetadataInterceptor(logger)
dialOpt := grpc.WithChainUnaryInterceptor(trailerInterceptor)
return map[primitives.ServiceName][]grpc.DialOption{
options := map[primitives.ServiceName][]grpc.DialOption{
primitives.HistoryService: {dialOpt},
primitives.MatchingService: {dialOpt},
}
return options
}

func RPCFactoryProvider(
Expand Down
39 changes: 39 additions & 0 deletions common/testing/testhooks/hooks.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,44 @@ import (
"go.temporal.io/server/api/historyservice/v1"
persistencespb "go.temporal.io/server/api/persistence/v1"
replicationspb "go.temporal.io/server/api/replication/v1"
"go.temporal.io/server/chasm"
"go.temporal.io/server/common/namespace"
historytasks "go.temporal.io/server/service/history/tasks"
"google.golang.org/protobuf/proto"
)

type (
HistoryChasmComponents struct {
Engine chasm.Engine
VisibilityManager chasm.VisibilityManager
Registry *chasm.Registry
}

ReplicationStreamMessageDirection string

ReplicationStreamMessage struct {
Method string
Direction ReplicationStreamMessageDirection
ClusterName string
TargetAddress string
Message proto.Message
IsStreamCall bool
}

HistoryTasksWritten struct {
ShardID int32
RangeID int64
NamespaceID string
WorkflowID string
Tasks map[historytasks.Category][]historytasks.Task
}
)

const (
ReplicationStreamDirectionSend ReplicationStreamMessageDirection = "send"
ReplicationStreamDirectionRecv ReplicationStreamMessageDirection = "recv"
ReplicationStreamDirectionServerSend ReplicationStreamMessageDirection = "server_send"
ReplicationStreamDirectionServerRecv ReplicationStreamMessageDirection = "server_recv"
)

// Test hook keys with their return type and scope.
Expand All @@ -28,6 +64,9 @@ var (
HistoryTransferTaskInterceptor = newKey[func(historytasks.Task, func()), namespace.ID]()
HistoryDLQTaskDeleteInterceptor = newKey[func(context.Context, *historyservice.DeleteDLQTasksRequest, func(context.Context, *historyservice.DeleteDLQTasksRequest) (*historyservice.DeleteDLQTasksResponse, error)) (*historyservice.DeleteDLQTasksResponse, error), global]()
NamespaceReplicationTaskInterceptor = newKey[func(context.Context, *replicationspb.NamespaceTaskAttributes, func() error) error, namespace.Name]()
HistoryTasksWrittenObserver = newKey[func(HistoryTasksWritten), global]()
ReplicationStreamMessageObserver = newKey[func(ReplicationStreamMessage), global]()
HistoryChasmComponentsCreated = newKey[func(HistoryChasmComponents), global]()
)

// keyID is a unique identifier for a key, used as a map key.
Expand Down
Loading
Loading