mirror of
https://codeberg.org/superseriousbusiness/gotosocial.git
synced 2024-12-26 19:08:19 +03:00
458 lines
10 KiB
Go
458 lines
10 KiB
Go
package migrate
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"os"
|
|
"path/filepath"
|
|
"regexp"
|
|
"time"
|
|
|
|
"github.com/uptrace/bun"
|
|
)
|
|
|
|
// Default table names used by NewMigrator. They can be overridden with
// WithTableName and WithLocksTableName.
const (
	// defaultLocksTable is the table holding the rows written by Lock.
	defaultLocksTable = "bun_migration_locks"
	// defaultTable is the table recording which migrations have been applied.
	defaultTable = "bun_migrations"
)
|
|
|
|
// MigratorOption customizes a Migrator. Options are passed to NewMigrator and
// applied in order after the defaults have been set.
type MigratorOption func(*Migrator)
|
|
|
|
// WithTableName overrides default migrations table name.
|
|
func WithTableName(table string) MigratorOption {
|
|
return func(m *Migrator) {
|
|
m.table = table
|
|
}
|
|
}
|
|
|
|
// WithLocksTableName overrides default migration locks table name.
|
|
func WithLocksTableName(table string) MigratorOption {
|
|
return func(m *Migrator) {
|
|
m.locksTable = table
|
|
}
|
|
}
|
|
|
|
// WithMarkAppliedOnSuccess sets the migrator to only mark migrations as applied/unapplied
|
|
// when their up/down is successful.
|
|
func WithMarkAppliedOnSuccess(enabled bool) MigratorOption {
|
|
return func(m *Migrator) {
|
|
m.markAppliedOnSuccess = enabled
|
|
}
|
|
}
|
|
|
|
type Migrator struct {
|
|
db *bun.DB
|
|
migrations *Migrations
|
|
|
|
ms MigrationSlice
|
|
|
|
table string
|
|
locksTable string
|
|
markAppliedOnSuccess bool
|
|
}
|
|
|
|
func NewMigrator(db *bun.DB, migrations *Migrations, opts ...MigratorOption) *Migrator {
|
|
m := &Migrator{
|
|
db: db,
|
|
migrations: migrations,
|
|
|
|
ms: migrations.ms,
|
|
|
|
table: defaultTable,
|
|
locksTable: defaultLocksTable,
|
|
}
|
|
for _, opt := range opts {
|
|
opt(m)
|
|
}
|
|
return m
|
|
}
|
|
|
|
func (m *Migrator) DB() *bun.DB {
|
|
return m.db
|
|
}
|
|
|
|
// MigrationsWithStatus returns migrations with status in ascending order.
|
|
func (m *Migrator) MigrationsWithStatus(ctx context.Context) (MigrationSlice, error) {
|
|
sorted, _, err := m.migrationsWithStatus(ctx)
|
|
return sorted, err
|
|
}
|
|
|
|
func (m *Migrator) migrationsWithStatus(ctx context.Context) (MigrationSlice, int64, error) {
|
|
sorted := m.migrations.Sorted()
|
|
|
|
applied, err := m.AppliedMigrations(ctx)
|
|
if err != nil {
|
|
return nil, 0, err
|
|
}
|
|
|
|
appliedMap := migrationMap(applied)
|
|
for i := range sorted {
|
|
m1 := &sorted[i]
|
|
if m2, ok := appliedMap[m1.Name]; ok {
|
|
m1.ID = m2.ID
|
|
m1.GroupID = m2.GroupID
|
|
m1.MigratedAt = m2.MigratedAt
|
|
}
|
|
}
|
|
|
|
return sorted, applied.LastGroupID(), nil
|
|
}
|
|
|
|
func (m *Migrator) Init(ctx context.Context) error {
|
|
if _, err := m.db.NewCreateTable().
|
|
Model((*Migration)(nil)).
|
|
ModelTableExpr(m.table).
|
|
IfNotExists().
|
|
Exec(ctx); err != nil {
|
|
return err
|
|
}
|
|
if _, err := m.db.NewCreateTable().
|
|
Model((*migrationLock)(nil)).
|
|
ModelTableExpr(m.locksTable).
|
|
IfNotExists().
|
|
Exec(ctx); err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (m *Migrator) Reset(ctx context.Context) error {
|
|
if _, err := m.db.NewDropTable().
|
|
Model((*Migration)(nil)).
|
|
ModelTableExpr(m.table).
|
|
IfExists().
|
|
Exec(ctx); err != nil {
|
|
return err
|
|
}
|
|
if _, err := m.db.NewDropTable().
|
|
Model((*migrationLock)(nil)).
|
|
ModelTableExpr(m.locksTable).
|
|
IfExists().
|
|
Exec(ctx); err != nil {
|
|
return err
|
|
}
|
|
return m.Init(ctx)
|
|
}
|
|
|
|
// Migrate runs unapplied migrations. If a migration fails, migrate immediately exits.
|
|
func (m *Migrator) Migrate(ctx context.Context, opts ...MigrationOption) (*MigrationGroup, error) {
|
|
cfg := newMigrationConfig(opts)
|
|
|
|
if err := m.validate(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
migrations, lastGroupID, err := m.migrationsWithStatus(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
migrations = migrations.Unapplied()
|
|
|
|
group := new(MigrationGroup)
|
|
if len(migrations) == 0 {
|
|
return group, nil
|
|
}
|
|
group.ID = lastGroupID + 1
|
|
|
|
for i := range migrations {
|
|
migration := &migrations[i]
|
|
migration.GroupID = group.ID
|
|
|
|
if !m.markAppliedOnSuccess {
|
|
if err := m.MarkApplied(ctx, migration); err != nil {
|
|
return group, err
|
|
}
|
|
}
|
|
|
|
group.Migrations = migrations[:i+1]
|
|
|
|
if !cfg.nop && migration.Up != nil {
|
|
if err := migration.Up(ctx, m.db); err != nil {
|
|
return group, err
|
|
}
|
|
}
|
|
|
|
if m.markAppliedOnSuccess {
|
|
if err := m.MarkApplied(ctx, migration); err != nil {
|
|
return group, err
|
|
}
|
|
}
|
|
}
|
|
|
|
return group, nil
|
|
}
|
|
|
|
func (m *Migrator) Rollback(ctx context.Context, opts ...MigrationOption) (*MigrationGroup, error) {
|
|
cfg := newMigrationConfig(opts)
|
|
|
|
if err := m.validate(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
migrations, err := m.MigrationsWithStatus(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
lastGroup := migrations.LastGroup()
|
|
|
|
for i := len(lastGroup.Migrations) - 1; i >= 0; i-- {
|
|
migration := &lastGroup.Migrations[i]
|
|
|
|
if !m.markAppliedOnSuccess {
|
|
if err := m.MarkUnapplied(ctx, migration); err != nil {
|
|
return lastGroup, err
|
|
}
|
|
}
|
|
|
|
if !cfg.nop && migration.Down != nil {
|
|
if err := migration.Down(ctx, m.db); err != nil {
|
|
return lastGroup, err
|
|
}
|
|
}
|
|
|
|
if m.markAppliedOnSuccess {
|
|
if err := m.MarkUnapplied(ctx, migration); err != nil {
|
|
return lastGroup, err
|
|
}
|
|
}
|
|
}
|
|
|
|
return lastGroup, nil
|
|
}
|
|
|
|
// goMigrationConfig holds the settings CreateGoMigration uses to render a new
// Go migration file. goTemplate is formatted with fmt.Sprintf, and packageName
// is its only argument.
type goMigrationConfig struct {
	packageName, goTemplate string
}
|
|
|
|
// GoMigrationOption customizes how CreateGoMigration renders a Go migration.
type GoMigrationOption func(*goMigrationConfig)
|
|
|
|
func WithPackageName(name string) GoMigrationOption {
|
|
return func(cfg *goMigrationConfig) {
|
|
cfg.packageName = name
|
|
}
|
|
}
|
|
|
|
func WithGoTemplate(template string) GoMigrationOption {
|
|
return func(cfg *goMigrationConfig) {
|
|
cfg.goTemplate = template
|
|
}
|
|
}
|
|
|
|
// CreateGoMigration creates a Go migration file.
|
|
func (m *Migrator) CreateGoMigration(
|
|
ctx context.Context, name string, opts ...GoMigrationOption,
|
|
) (*MigrationFile, error) {
|
|
cfg := &goMigrationConfig{
|
|
packageName: "migrations",
|
|
goTemplate: goTemplate,
|
|
}
|
|
for _, opt := range opts {
|
|
opt(cfg)
|
|
}
|
|
|
|
name, err := genMigrationName(name)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
fname := name + ".go"
|
|
fpath := filepath.Join(m.migrations.getDirectory(), fname)
|
|
content := fmt.Sprintf(cfg.goTemplate, cfg.packageName)
|
|
|
|
if err := os.WriteFile(fpath, []byte(content), 0o644); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
mf := &MigrationFile{
|
|
Name: fname,
|
|
Path: fpath,
|
|
Content: content,
|
|
}
|
|
return mf, nil
|
|
}
|
|
|
|
// CreateTxSQLMigration creates transactional up and down SQL migration files.
|
|
func (m *Migrator) CreateTxSQLMigrations(ctx context.Context, name string) ([]*MigrationFile, error) {
|
|
name, err := genMigrationName(name)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
up, err := m.createSQL(ctx, name+".tx.up.sql", true)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
down, err := m.createSQL(ctx, name+".tx.down.sql", true)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return []*MigrationFile{up, down}, nil
|
|
}
|
|
|
|
// CreateSQLMigrations creates up and down SQL migration files.
|
|
func (m *Migrator) CreateSQLMigrations(ctx context.Context, name string) ([]*MigrationFile, error) {
|
|
name, err := genMigrationName(name)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
up, err := m.createSQL(ctx, name+".up.sql", false)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
down, err := m.createSQL(ctx, name+".down.sql", false)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return []*MigrationFile{up, down}, nil
|
|
}
|
|
|
|
func (m *Migrator) createSQL(_ context.Context, fname string, transactional bool) (*MigrationFile, error) {
|
|
fpath := filepath.Join(m.migrations.getDirectory(), fname)
|
|
|
|
template := sqlTemplate
|
|
if transactional {
|
|
template = transactionalSQLTemplate
|
|
}
|
|
|
|
if err := os.WriteFile(fpath, []byte(template), 0o644); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
mf := &MigrationFile{
|
|
Name: fname,
|
|
Path: fpath,
|
|
Content: goTemplate,
|
|
}
|
|
return mf, nil
|
|
}
|
|
|
|
// nameRE restricts migration names to lowercase ASCII letters, digits,
// underscores and hyphens.
var nameRE = regexp.MustCompile(`^[0-9a-z_\-]+$`)

// genMigrationName validates name and prefixes it with the current UTC time
// in YYYYMMDDhhmmss form, producing for example "20240102150405_add_users".
// It returns an error if name is empty or does not match nameRE.
func genMigrationName(name string) (string, error) {
	const timeFormat = "20060102150405"

	switch {
	case name == "":
		return "", errors.New("migrate: migration name can't be empty")
	case !nameRE.MatchString(name):
		return "", fmt.Errorf("migrate: invalid migration name: %q", name)
	}

	stamp := time.Now().UTC().Format(timeFormat)
	return stamp + "_" + name, nil
}
|
|
|
|
// MarkApplied marks the migration as applied (completed).
|
|
func (m *Migrator) MarkApplied(ctx context.Context, migration *Migration) error {
|
|
_, err := m.db.NewInsert().Model(migration).
|
|
ModelTableExpr(m.table).
|
|
Exec(ctx)
|
|
return err
|
|
}
|
|
|
|
// MarkUnapplied marks the migration as unapplied (new).
|
|
func (m *Migrator) MarkUnapplied(ctx context.Context, migration *Migration) error {
|
|
_, err := m.db.NewDelete().
|
|
Model(migration).
|
|
ModelTableExpr(m.table).
|
|
Where("id = ?", migration.ID).
|
|
Exec(ctx)
|
|
return err
|
|
}
|
|
|
|
func (m *Migrator) TruncateTable(ctx context.Context) error {
|
|
_, err := m.db.NewTruncateTable().
|
|
Model((*Migration)(nil)).
|
|
ModelTableExpr(m.table).
|
|
Exec(ctx)
|
|
return err
|
|
}
|
|
|
|
// MissingMigrations returns applied migrations that can no longer be found.
|
|
func (m *Migrator) MissingMigrations(ctx context.Context) (MigrationSlice, error) {
|
|
applied, err := m.AppliedMigrations(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
existing := migrationMap(m.migrations.ms)
|
|
for i := len(applied) - 1; i >= 0; i-- {
|
|
m := &applied[i]
|
|
if _, ok := existing[m.Name]; ok {
|
|
applied = append(applied[:i], applied[i+1:]...)
|
|
}
|
|
}
|
|
|
|
return applied, nil
|
|
}
|
|
|
|
// AppliedMigrations selects applied (applied) migrations in descending order.
|
|
func (m *Migrator) AppliedMigrations(ctx context.Context) (MigrationSlice, error) {
|
|
var ms MigrationSlice
|
|
if err := m.db.NewSelect().
|
|
ColumnExpr("*").
|
|
Model(&ms).
|
|
ModelTableExpr(m.table).
|
|
Scan(ctx); err != nil {
|
|
return nil, err
|
|
}
|
|
return ms, nil
|
|
}
|
|
|
|
func (m *Migrator) formattedTableName(db *bun.DB) string {
|
|
return db.Formatter().FormatQuery(m.table)
|
|
}
|
|
|
|
func (m *Migrator) validate() error {
|
|
if len(m.ms) == 0 {
|
|
return errors.New("migrate: there are no migrations")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
//------------------------------------------------------------------------------
|
|
|
|
// migrationLock is a row of the locks table. Lock inserts one row per
// migrations table, and TableName is declared unique so that a second insert
// for the same table is rejected. Field order and tags define the table
// schema created by Init, so they must stay as they are.
type migrationLock struct {
	// ID is the auto-incremented primary key.
	ID int64 `bun:",pk,autoincrement"`
	// TableName is the formatted name of the locked migrations table.
	TableName string `bun:",unique"`
}
|
|
|
|
func (m *Migrator) Lock(ctx context.Context) error {
|
|
lock := &migrationLock{
|
|
TableName: m.formattedTableName(m.db),
|
|
}
|
|
if _, err := m.db.NewInsert().
|
|
Model(lock).
|
|
ModelTableExpr(m.locksTable).
|
|
Exec(ctx); err != nil {
|
|
return fmt.Errorf("migrate: migrations table is already locked (%w)", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (m *Migrator) Unlock(ctx context.Context) error {
|
|
tableName := m.formattedTableName(m.db)
|
|
_, err := m.db.NewDelete().
|
|
Model((*migrationLock)(nil)).
|
|
ModelTableExpr(m.locksTable).
|
|
Where("? = ?", bun.Ident("table_name"), tableName).
|
|
Exec(ctx)
|
|
return err
|
|
}
|
|
|
|
func migrationMap(ms MigrationSlice) map[string]*Migration {
|
|
mp := make(map[string]*Migration)
|
|
for i := range ms {
|
|
m := &ms[i]
|
|
mp[m.Name] = m
|
|
}
|
|
return mp
|
|
}
|