mirror of
https://codeberg.org/superseriousbusiness/gotosocial.git
synced 2024-12-25 10:28:18 +03:00
295 lines
6.8 KiB
Go
295 lines
6.8 KiB
Go
|
package sqlite3
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"errors"
|
||
|
"fmt"
|
||
|
"math/rand"
|
||
|
"runtime"
|
||
|
"strconv"
|
||
|
"strings"
|
||
|
|
||
|
"github.com/ncruces/go-sqlite3/internal/util"
|
||
|
"github.com/tetratelabs/wazero/api"
|
||
|
)
|
||
|
|
||
|
// Txn is an in-progress database transaction.
|
||
|
//
|
||
|
// https://sqlite.org/lang_transaction.html
|
||
|
type Txn struct {
|
||
|
c *Conn
|
||
|
}
|
||
|
|
||
|
// Begin starts a deferred transaction.
|
||
|
//
|
||
|
// https://sqlite.org/lang_transaction.html
|
||
|
func (c *Conn) Begin() Txn {
|
||
|
// BEGIN even if interrupted.
|
||
|
err := c.txnExecInterrupted(`BEGIN DEFERRED`)
|
||
|
if err != nil {
|
||
|
panic(err)
|
||
|
}
|
||
|
return Txn{c}
|
||
|
}
|
||
|
|
||
|
// BeginImmediate starts an immediate transaction.
|
||
|
//
|
||
|
// https://sqlite.org/lang_transaction.html
|
||
|
func (c *Conn) BeginImmediate() (Txn, error) {
|
||
|
err := c.Exec(`BEGIN IMMEDIATE`)
|
||
|
if err != nil {
|
||
|
return Txn{}, err
|
||
|
}
|
||
|
return Txn{c}, nil
|
||
|
}
|
||
|
|
||
|
// BeginExclusive starts an exclusive transaction.
|
||
|
//
|
||
|
// https://sqlite.org/lang_transaction.html
|
||
|
func (c *Conn) BeginExclusive() (Txn, error) {
|
||
|
err := c.Exec(`BEGIN EXCLUSIVE`)
|
||
|
if err != nil {
|
||
|
return Txn{}, err
|
||
|
}
|
||
|
return Txn{c}, nil
|
||
|
}
|
||
|
|
||
|
// End calls either [Txn.Commit] or [Txn.Rollback]
|
||
|
// depending on whether *error points to a nil or non-nil error.
|
||
|
//
|
||
|
// This is meant to be deferred:
|
||
|
//
|
||
|
// func doWork(db *sqlite3.Conn) (err error) {
|
||
|
// tx := db.Begin()
|
||
|
// defer tx.End(&err)
|
||
|
//
|
||
|
// // ... do work in the transaction
|
||
|
// }
|
||
|
//
|
||
|
// https://sqlite.org/lang_transaction.html
|
||
|
func (tx Txn) End(errp *error) {
|
||
|
recovered := recover()
|
||
|
if recovered != nil {
|
||
|
defer panic(recovered)
|
||
|
}
|
||
|
|
||
|
if *errp == nil && recovered == nil {
|
||
|
// Success path.
|
||
|
if tx.c.GetAutocommit() { // There is nothing to commit.
|
||
|
return
|
||
|
}
|
||
|
*errp = tx.Commit()
|
||
|
if *errp == nil {
|
||
|
return
|
||
|
}
|
||
|
// Fall through to the error path.
|
||
|
}
|
||
|
|
||
|
// Error path.
|
||
|
if tx.c.GetAutocommit() { // There is nothing to rollback.
|
||
|
return
|
||
|
}
|
||
|
err := tx.Rollback()
|
||
|
if err != nil {
|
||
|
panic(err)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Commit commits the transaction.
|
||
|
//
|
||
|
// https://sqlite.org/lang_transaction.html
|
||
|
func (tx Txn) Commit() error {
|
||
|
return tx.c.Exec(`COMMIT`)
|
||
|
}
|
||
|
|
||
|
// Rollback rolls back the transaction,
|
||
|
// even if the connection has been interrupted.
|
||
|
//
|
||
|
// https://sqlite.org/lang_transaction.html
|
||
|
func (tx Txn) Rollback() error {
|
||
|
return tx.c.txnExecInterrupted(`ROLLBACK`)
|
||
|
}
|
||
|
|
||
|
// Savepoint is a marker within a transaction
|
||
|
// that allows for partial rollback.
|
||
|
//
|
||
|
// https://sqlite.org/lang_savepoint.html
|
||
|
type Savepoint struct {
|
||
|
c *Conn
|
||
|
name string
|
||
|
}
|
||
|
|
||
|
// Savepoint establishes a new transaction savepoint.
|
||
|
//
|
||
|
// https://sqlite.org/lang_savepoint.html
|
||
|
func (c *Conn) Savepoint() Savepoint {
|
||
|
// Names can be reused; this makes catching bugs more likely.
|
||
|
name := saveptName() + "_" + strconv.Itoa(int(rand.Int31()))
|
||
|
|
||
|
err := c.txnExecInterrupted(fmt.Sprintf("SAVEPOINT %q;", name))
|
||
|
if err != nil {
|
||
|
panic(err)
|
||
|
}
|
||
|
return Savepoint{c: c, name: name}
|
||
|
}
|
||
|
|
||
|
// saveptName returns the name of the function that is three frames
// above runtime.Callers, i.e. the caller of saveptName's caller.
// Frames that belong to database/sql or to this module's driver package
// are skipped, so that the name refers to the code using them.
// It falls back to "sqlite3.Savepoint" when no name can be determined.
func saveptName() string {
	const fallback = "sqlite3.Savepoint"

	var pcs [8]uintptr
	count := runtime.Callers(3, pcs[:])
	if count <= 0 {
		return fallback
	}

	iter := runtime.CallersFrames(pcs[:count])
	for {
		fr, more := iter.Next()
		wrapper := strings.HasPrefix(fr.Function, "database/sql.") ||
			strings.HasPrefix(fr.Function, "github.com/ncruces/go-sqlite3/driver.")
		if more && wrapper {
			continue // keep looking for a frame outside the wrappers
		}
		// Either the first non-wrapper frame, or the last frame available.
		if fr.Function == "" {
			return fallback
		}
		return fr.Function
	}
}
|
||
|
|
||
|
// Release releases the savepoint rolling back any changes
|
||
|
// if *error points to a non-nil error.
|
||
|
//
|
||
|
// This is meant to be deferred:
|
||
|
//
|
||
|
// func doWork(db *sqlite3.Conn) (err error) {
|
||
|
// savept := db.Savepoint()
|
||
|
// defer savept.Release(&err)
|
||
|
//
|
||
|
// // ... do work in the transaction
|
||
|
// }
|
||
|
func (s Savepoint) Release(errp *error) {
|
||
|
recovered := recover()
|
||
|
if recovered != nil {
|
||
|
defer panic(recovered)
|
||
|
}
|
||
|
|
||
|
if *errp == nil && recovered == nil {
|
||
|
// Success path.
|
||
|
if s.c.GetAutocommit() { // There is nothing to commit.
|
||
|
return
|
||
|
}
|
||
|
*errp = s.c.Exec(fmt.Sprintf("RELEASE %q;", s.name))
|
||
|
if *errp == nil {
|
||
|
return
|
||
|
}
|
||
|
// Fall through to the error path.
|
||
|
}
|
||
|
|
||
|
// Error path.
|
||
|
if s.c.GetAutocommit() { // There is nothing to rollback.
|
||
|
return
|
||
|
}
|
||
|
// ROLLBACK and RELEASE even if interrupted.
|
||
|
err := s.c.txnExecInterrupted(fmt.Sprintf(`
|
||
|
ROLLBACK TO %[1]q;
|
||
|
RELEASE %[1]q;
|
||
|
`, s.name))
|
||
|
if err != nil {
|
||
|
panic(err)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Rollback rolls the transaction back to the savepoint,
|
||
|
// even if the connection has been interrupted.
|
||
|
// Rollback does not release the savepoint.
|
||
|
//
|
||
|
// https://sqlite.org/lang_transaction.html
|
||
|
func (s Savepoint) Rollback() error {
|
||
|
// ROLLBACK even if interrupted.
|
||
|
return s.c.txnExecInterrupted(fmt.Sprintf("ROLLBACK TO %q;", s.name))
|
||
|
}
|
||
|
|
||
|
func (c *Conn) txnExecInterrupted(sql string) error {
|
||
|
err := c.Exec(sql)
|
||
|
if errors.Is(err, INTERRUPT) {
|
||
|
old := c.SetInterrupt(context.Background())
|
||
|
defer c.SetInterrupt(old)
|
||
|
err = c.Exec(sql)
|
||
|
}
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
// TxnState starts a deferred transaction.
|
||
|
//
|
||
|
// https://sqlite.org/c3ref/txn_state.html
|
||
|
func (c *Conn) TxnState(schema string) TxnState {
|
||
|
var ptr uint32
|
||
|
if schema != "" {
|
||
|
defer c.arena.mark()()
|
||
|
ptr = c.arena.string(schema)
|
||
|
}
|
||
|
r := c.call("sqlite3_txn_state", uint64(c.handle), uint64(ptr))
|
||
|
return TxnState(r)
|
||
|
}
|
||
|
|
||
|
// CommitHook registers a callback function to be invoked
|
||
|
// whenever a transaction is committed.
|
||
|
// Return true to allow the commit operation to continue normally.
|
||
|
//
|
||
|
// https://sqlite.org/c3ref/commit_hook.html
|
||
|
func (c *Conn) CommitHook(cb func() (ok bool)) {
|
||
|
var enable uint64
|
||
|
if cb != nil {
|
||
|
enable = 1
|
||
|
}
|
||
|
c.call("sqlite3_commit_hook_go", uint64(c.handle), enable)
|
||
|
c.commit = cb
|
||
|
}
|
||
|
|
||
|
// RollbackHook registers a callback function to be invoked
|
||
|
// whenever a transaction is rolled back.
|
||
|
//
|
||
|
// https://sqlite.org/c3ref/commit_hook.html
|
||
|
func (c *Conn) RollbackHook(cb func()) {
|
||
|
var enable uint64
|
||
|
if cb != nil {
|
||
|
enable = 1
|
||
|
}
|
||
|
c.call("sqlite3_rollback_hook_go", uint64(c.handle), enable)
|
||
|
c.rollback = cb
|
||
|
}
|
||
|
|
||
|
// UpdateHook registers a callback function to be invoked
|
||
|
// whenever a row is updated, inserted or deleted in a rowid table.
|
||
|
//
|
||
|
// https://sqlite.org/c3ref/update_hook.html
|
||
|
func (c *Conn) UpdateHook(cb func(action AuthorizerActionCode, schema, table string, rowid int64)) {
|
||
|
var enable uint64
|
||
|
if cb != nil {
|
||
|
enable = 1
|
||
|
}
|
||
|
c.call("sqlite3_update_hook_go", uint64(c.handle), enable)
|
||
|
c.update = cb
|
||
|
}
|
||
|
|
||
|
func commitCallback(ctx context.Context, mod api.Module, pDB uint32) (rollback uint32) {
|
||
|
if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.commit != nil {
|
||
|
if !c.commit() {
|
||
|
rollback = 1
|
||
|
}
|
||
|
}
|
||
|
return rollback
|
||
|
}
|
||
|
|
||
|
func rollbackCallback(ctx context.Context, mod api.Module, pDB uint32) {
|
||
|
if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.rollback != nil {
|
||
|
c.rollback()
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func updateCallback(ctx context.Context, mod api.Module, pDB uint32, action AuthorizerActionCode, zSchema, zTabName uint32, rowid uint64) {
|
||
|
if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.update != nil {
|
||
|
schema := util.ReadString(mod, zSchema, _MAX_NAME)
|
||
|
table := util.ReadString(mod, zTabName, _MAX_NAME)
|
||
|
c.update(action, schema, table, int64(rowid))
|
||
|
}
|
||
|
}
|