2021-08-25 15:34:33 +02:00
|
|
|
package schema
|
2021-08-12 21:03:24 +02:00
|
|
|
|
|
|
|
import (
|
|
|
|
"fmt"
|
|
|
|
"reflect"
|
|
|
|
"sync"
|
2024-11-08 13:51:23 +00:00
|
|
|
|
|
|
|
"github.com/puzpuzpuz/xsync/v3"
|
2021-08-12 21:03:24 +02:00
|
|
|
)
|
|
|
|
|
2021-08-25 15:34:33 +02:00
|
|
|
type Tables struct {
|
|
|
|
dialect Dialect
|
2021-08-12 21:03:24 +02:00
|
|
|
|
2024-11-08 13:51:23 +00:00
|
|
|
mu sync.Mutex
|
|
|
|
tables *xsync.MapOf[reflect.Type, *Table]
|
|
|
|
|
|
|
|
inProgress map[reflect.Type]*Table
|
2021-08-12 21:03:24 +02:00
|
|
|
}
|
|
|
|
|
2021-08-25 15:34:33 +02:00
|
|
|
func NewTables(dialect Dialect) *Tables {
|
|
|
|
return &Tables{
|
|
|
|
dialect: dialect,
|
2024-11-08 13:51:23 +00:00
|
|
|
tables: xsync.NewMapOf[reflect.Type, *Table](),
|
|
|
|
inProgress: make(map[reflect.Type]*Table),
|
2021-08-12 21:03:24 +02:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2021-08-25 15:34:33 +02:00
|
|
|
func (t *Tables) Register(models ...interface{}) {
|
|
|
|
for _, model := range models {
|
|
|
|
_ = t.Get(reflect.TypeOf(model).Elem())
|
2021-08-12 21:03:24 +02:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2021-08-25 15:34:33 +02:00
|
|
|
func (t *Tables) Get(typ reflect.Type) *Table {
|
2021-09-08 20:05:26 +01:00
|
|
|
typ = indirectType(typ)
|
2021-08-12 21:03:24 +02:00
|
|
|
if typ.Kind() != reflect.Struct {
|
|
|
|
panic(fmt.Errorf("got %s, wanted %s", typ.Kind(), reflect.Struct))
|
|
|
|
}
|
|
|
|
|
|
|
|
if v, ok := t.tables.Load(typ); ok {
|
2024-11-08 13:51:23 +00:00
|
|
|
return v
|
2021-08-12 21:03:24 +02:00
|
|
|
}
|
|
|
|
|
|
|
|
t.mu.Lock()
|
2024-11-08 13:51:23 +00:00
|
|
|
defer t.mu.Unlock()
|
2021-08-12 21:03:24 +02:00
|
|
|
|
|
|
|
if v, ok := t.tables.Load(typ); ok {
|
2024-11-08 13:51:23 +00:00
|
|
|
return v
|
2021-08-12 21:03:24 +02:00
|
|
|
}
|
|
|
|
|
2024-11-08 13:51:23 +00:00
|
|
|
table := t.InProgress(typ)
|
|
|
|
table.initRelations()
|
2021-11-13 12:29:08 +01:00
|
|
|
|
2021-08-25 15:34:33 +02:00
|
|
|
t.dialect.OnTable(table)
|
|
|
|
for _, field := range table.FieldMap {
|
|
|
|
if field.UserSQLType == "" {
|
|
|
|
field.UserSQLType = field.DiscoveredSQLType
|
|
|
|
}
|
|
|
|
if field.CreateTableSQLType == "" {
|
|
|
|
field.CreateTableSQLType = field.UserSQLType
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2024-11-08 13:51:23 +00:00
|
|
|
t.tables.Store(typ, table)
|
|
|
|
return table
|
|
|
|
}
|
|
|
|
|
|
|
|
func (t *Tables) InProgress(typ reflect.Type) *Table {
|
|
|
|
if table, ok := t.inProgress[typ]; ok {
|
|
|
|
return table
|
|
|
|
}
|
|
|
|
|
|
|
|
table := new(Table)
|
|
|
|
t.inProgress[typ] = table
|
|
|
|
table.init(t.dialect, typ, false)
|
|
|
|
|
2021-08-12 21:03:24 +02:00
|
|
|
return table
|
|
|
|
}
|
|
|
|
|
2024-11-25 15:42:37 +00:00
|
|
|
// ByModel gets the table by its Go name.
|
2021-08-25 15:34:33 +02:00
|
|
|
func (t *Tables) ByModel(name string) *Table {
|
|
|
|
var found *Table
|
2024-11-08 13:51:23 +00:00
|
|
|
t.tables.Range(func(typ reflect.Type, table *Table) bool {
|
|
|
|
if table.TypeName == name {
|
|
|
|
found = table
|
2021-08-25 15:34:33 +02:00
|
|
|
return false
|
|
|
|
}
|
|
|
|
return true
|
|
|
|
})
|
|
|
|
return found
|
2021-08-12 21:03:24 +02:00
|
|
|
}
|
|
|
|
|
2024-11-25 15:42:37 +00:00
|
|
|
// ByName gets the table by its SQL name.
|
2021-08-25 15:34:33 +02:00
|
|
|
func (t *Tables) ByName(name string) *Table {
|
2021-08-12 21:03:24 +02:00
|
|
|
var found *Table
|
2024-11-08 13:51:23 +00:00
|
|
|
t.tables.Range(func(typ reflect.Type, table *Table) bool {
|
|
|
|
if table.Name == name {
|
|
|
|
found = table
|
2021-08-12 21:03:24 +02:00
|
|
|
return false
|
|
|
|
}
|
|
|
|
return true
|
|
|
|
})
|
|
|
|
return found
|
|
|
|
}
|
2024-11-25 15:42:37 +00:00
|
|
|
|
|
|
|
// All returns all registered tables.
|
|
|
|
func (t *Tables) All() []*Table {
|
|
|
|
var found []*Table
|
|
|
|
t.tables.Range(func(typ reflect.Type, table *Table) bool {
|
|
|
|
found = append(found, table)
|
|
|
|
return true
|
|
|
|
})
|
|
|
|
return found
|
|
|
|
}
|