diff --git a/common/persistence/cassandra/test.go b/common/persistence/cassandra/test.go index 1adfc15651f..1755faefae7 100644 --- a/common/persistence/cassandra/test.go +++ b/common/persistence/cassandra/test.go @@ -17,12 +17,32 @@ import ( commongocql "go.temporal.io/server/common/persistence/nosql/nosqlplugin/cassandra/gocql" "go.temporal.io/server/common/primitives" "go.temporal.io/server/common/resolver" + cassandraschema "go.temporal.io/server/schema/cassandra" "go.temporal.io/server/temporal/environment" "go.temporal.io/server/tests/testutils" ) const ( testSchemaDir = "schema/cassandra/" + + createSchemaVersionTableCQL = `CREATE TABLE IF NOT EXISTS schema_version(keyspace_name text PRIMARY KEY, ` + + `creation_time timestamp, ` + + `curr_version text, ` + + `min_compatible_version text);` + + createSchemaUpdateHistoryTableCQL = `CREATE TABLE IF NOT EXISTS schema_update_history(` + + `year int, ` + + `month int, ` + + `update_time timestamp, ` + + `description text, ` + + `manifest_md5 text, ` + + `new_version text, ` + + `old_version text, ` + + `PRIMARY KEY ((year, month), update_time));` + + writeSchemaVersionCQL = `INSERT into schema_version(keyspace_name, creation_time, curr_version, min_compatible_version) VALUES (?,?,?,?)` + + writeSchemaUpdateHistoryCQL = `INSERT into schema_update_history(year, month, update_time, old_version, new_version, manifest_md5, description) VALUES(?,?,?,?,?,?,?)` ) // TestCluster allows executing cassandra operations in testing. @@ -93,6 +113,7 @@ func (s *TestCluster) SetupTestDatabase() { } s.LoadSchema(path.Join(schemaDir, "temporal", "schema.cql")) + s.loadSchemaVersion() } // TearDownTestDatabase from PersistenceTestCluster interface @@ -181,6 +202,38 @@ func (s *TestCluster) LoadSchema(schemaFile string) { s.logger.Info("loaded schema") } +func (s *TestCluster) loadSchemaVersion() { + s.createSchemaVersionTables() + s.updateSchemaVersion(cassandraschema.Version, cassandraschema.Version) + s.writeSchemaUpdateLog("0", cassandraschema.Version, "", "initial version") + s.logger.Info("loaded schema version", tag.String("version", cassandraschema.Version)) +} + +func (s *TestCluster) createSchemaVersionTables() { + s.execSchemaVersionQuery(createSchemaVersionTableCQL) + s.execSchemaVersionQuery(createSchemaUpdateHistoryTableCQL) +} + +func (s *TestCluster) updateSchemaVersion(newVersion string, minCompatibleVersion string) { + now := time.Now().UTC() + s.execSchemaVersionQuery( + writeSchemaVersionCQL, + s.keyspace, now, newVersion, minCompatibleVersion) +} + +func (s *TestCluster) writeSchemaUpdateLog(oldVersion string, newVersion string, manifestMD5 string, description string) { + now := time.Now().UTC() + s.execSchemaVersionQuery( + writeSchemaUpdateHistoryCQL, + now.Year(), int(now.Month()), now, oldVersion, newVersion, manifestMD5, description) +} + +func (s *TestCluster) execSchemaVersionQuery(stmt string, args ...any) { + if err := s.session.Query(stmt, args...).Exec(); err != nil { + s.logger.Fatal("loadSchemaVersion", tag.Error(err)) + } +} + func (s *TestCluster) GetSession() commongocql.Session { return s.session } diff --git a/common/persistence/sql/test_sql_persistence.go b/common/persistence/sql/test_sql_persistence.go index b92589e9a1a..557afda9cf2 100644 --- a/common/persistence/sql/test_sql_persistence.go +++ b/common/persistence/sql/test_sql_persistence.go @@ -83,6 +83,7 @@ func (s *TestCluster) SetupTestDatabase() { } s.LoadSchema(path.Join(schemaDir, "temporal", "schema.sql")) s.LoadSchema(path.Join(schemaDir, "visibility", "schema.sql")) + s.loadSchemaVersion() } // Config returns the persistence config for connecting to this test cluster @@ -111,26 +112,9 @@ func (s *TestCluster) CreateDatabase() { cfg2.DatabaseName = "" } - var db sqlplugin.AdminDB - var err error - err = backoff.ThrottleRetry( - func() error { - db, err = NewSQLAdminDB(sqlplugin.DbKindUnknown, &cfg2, resolver.NewNoopResolver(), log.NewTestLogger(), metrics.NoopMetricsHandler) - return err - }, - backoff.NewExponentialRetryPolicy(time.Second).WithExpirationInterval(time.Minute), - nil, - ) - if err != nil { - panic(err) - } - defer func() { - err := db.Close() - if err != nil { - panic(err) - } - }() - err = db.CreateDatabase(s.cfg.DatabaseName) + db := s.newAdminDB(sqlplugin.DbKindUnknown, &cfg2) + defer s.closeAdminDB(db) + err := db.CreateDatabase(s.cfg.DatabaseName) if err != nil { panic(err) } @@ -154,18 +138,9 @@ func (s *TestCluster) DropDatabase() { // NOTE need to connect with empty name to drop the database cfg2.DatabaseName = "" - db, err := NewSQLAdminDB(sqlplugin.DbKindUnknown, &cfg2, resolver.NewNoopResolver(), log.NewTestLogger(), metrics.NoopMetricsHandler) - if err != nil { - panic(err) - } - defer func() { - err := db.Close() - if err != nil { - panic(err) - } - }() - err = db.DropDatabase(s.cfg.DatabaseName) - if err != nil { + db := s.newAdminDB(sqlplugin.DbKindUnknown, &cfg2) + defer s.closeAdminDB(db) + if err := db.DropDatabase(s.cfg.DatabaseName); err != nil { panic(err) } s.logger.Info("dropped database", tag.String("database", s.cfg.DatabaseName)) @@ -183,24 +158,8 @@ func (s *TestCluster) LoadSchema(schemaFile string) { ) } - var db sqlplugin.AdminDB - err = backoff.ThrottleRetry( - func() error { - db, err = NewSQLAdminDB(sqlplugin.DbKindUnknown, &s.cfg, resolver.NewNoopResolver(), log.NewTestLogger(), metrics.NoopMetricsHandler) - return err - }, - backoff.NewExponentialRetryPolicy(time.Second).WithExpirationInterval(time.Minute), - nil, - ) - if err != nil { - panic(err) - } - defer func() { - err := db.Close() - if err != nil { - panic(err) - } - }() + db := s.newAdminDB(sqlplugin.DbKindUnknown, &s.cfg) + defer s.closeAdminDB(db) if rewriter, ok := db.(sqlplugin.SchemaStatementRewriter); ok { statements = rewriter.RewriteSchemaStatements(statements) @@ -220,3 +179,43 @@ func (s *TestCluster) LoadSchema(schemaFile string) { } s.logger.Info("loaded schema") } + +func (s *TestCluster) loadSchemaVersion() { + db := s.newAdminDB(sqlplugin.DbKindMain, &s.cfg) + defer s.closeAdminDB(db) + + expectedVersion := db.ExpectedVersion() + if err := db.CreateSchemaVersionTables(); err != nil { + s.logger.Fatal("CreateSchemaVersionTables", tag.Error(err)) + } + if err := db.UpdateSchemaVersion(s.cfg.DatabaseName, expectedVersion, expectedVersion); err != nil { + s.logger.Fatal("UpdateSchemaVersion", tag.Error(err)) + } + if err := db.WriteSchemaUpdateLog("0", expectedVersion, "", "initial version"); err != nil { + s.logger.Fatal("WriteSchemaUpdateLog", tag.Error(err)) + } + s.logger.Info("loaded schema version", tag.String("version", expectedVersion)) +} + +func (s *TestCluster) newAdminDB(kind sqlplugin.DbKind, cfg *config.SQL) sqlplugin.AdminDB { + var db sqlplugin.AdminDB + var err error + err = backoff.ThrottleRetry( + func() error { + db, err = NewSQLAdminDB(kind, cfg, resolver.NewNoopResolver(), log.NewTestLogger(), metrics.NoopMetricsHandler) + return err + }, + backoff.NewExponentialRetryPolicy(time.Second).WithExpirationInterval(time.Minute), + nil, + ) + if err != nil { + s.logger.Fatal("NewSQLAdminDB", tag.Error(err)) + } + return db +} + +func (s *TestCluster) closeAdminDB(db sqlplugin.AdminDB) { + if err := db.Close(); err != nil { + s.logger.Fatal("Close schema DB", tag.Error(err)) + } +}