From 7594fc7456cb7f0e708c7dd054222e1b60ff27c3 Mon Sep 17 00:00:00 2001 From: f0x Date: Mon, 3 Oct 2022 18:47:29 +0000 Subject: [PATCH] Merge commit '4d1c7c871bb192bc05ce1cb6fb80773126373111' --- internal/cache/account.go | 5 + internal/db/account.go | 5 + internal/db/bundb/account.go | 140 ++++---- internal/db/bundb/admin.go | 82 +++-- internal/db/bundb/admin_test.go | 39 +++ internal/db/bundb/basic.go | 2 +- internal/db/bundb/bundb.go | 34 +- internal/db/bundb/bundb_test.go | 2 + internal/db/bundb/domain.go | 5 +- internal/db/bundb/emoji.go | 24 +- internal/db/bundb/instance.go | 43 ++- internal/db/bundb/instance_test.go | 83 +++++ internal/db/bundb/media.go | 30 +- internal/db/bundb/mention.go | 2 +- internal/db/bundb/notification.go | 20 +- internal/db/bundb/relationship.go | 298 ++++++++++-------- internal/db/bundb/relationship_test.go | 269 +++++++++++++--- internal/db/bundb/user.go | 8 +- internal/db/bundb/util.go | 16 +- internal/db/params.go | 3 - .../processing/admin/createdomainblock.go | 6 +- .../processing/admin/deletedomainblock.go | 2 +- 22 files changed, 753 insertions(+), 365 deletions(-) create mode 100644 internal/db/bundb/instance_test.go diff --git a/internal/cache/account.go b/internal/cache/account.go index 7e23c3194..12675b6b9 100644 --- a/internal/cache/account.go +++ b/internal/cache/account.go @@ -101,6 +101,11 @@ func (c *AccountCache) Put(account *gtsmodel.Account) { c.cache.Set(account.ID, copyAccount(account)) } +// Invalidate removes (invalidates) one account from the cache by its ID. +func (c *AccountCache) Invalidate(id string) { + c.cache.Invalidate(id) +} + // copyAccount performs a surface-level copy of account, only keeping attached IDs intact, not the objects. // due to all the data being copied being 99% primitive types or strings (which are immutable and passed by ptr) // this should be a relatively cheap process diff --git a/internal/db/account.go b/internal/db/account.go index 351d6d01c..ae5eea7c6 100644 --- a/internal/db/account.go +++ b/internal/db/account.go @@ -48,6 +48,11 @@ type Account interface { // UpdateAccount updates one account by ID. UpdateAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, Error) + // DeleteAccount deletes one account from the database by its ID. + // DO NOT USE THIS WHEN SUSPENDING ACCOUNTS! In that case you should mark the + // account as suspended instead, rather than deleting from the db entirely. + DeleteAccount(ctx context.Context, id string) Error + // GetAccountCustomCSSByUsername returns the custom css of an account on this instance with the given username. GetAccountCustomCSSByUsername(ctx context.Context, username string) (string, Error) diff --git a/internal/db/bundb/account.go b/internal/db/bundb/account.go index 074804690..7ed443f61 100644 --- a/internal/db/bundb/account.go +++ b/internal/db/bundb/account.go @@ -21,7 +21,6 @@ package bundb import ( "context" "errors" - "fmt" "strings" "time" @@ -56,7 +55,7 @@ func (a *accountDB) GetAccountByID(ctx context.Context, id string) (*gtsmodel.Ac return a.cache.GetByID(id) }, func(account *gtsmodel.Account) error { - return a.newAccountQ(account).Where("account.id = ?", id).Scan(ctx) + return a.newAccountQ(account).Where("? = ?", bun.Ident("account.id"), id).Scan(ctx) }, ) } @@ -68,7 +67,7 @@ func (a *accountDB) GetAccountByURI(ctx context.Context, uri string) (*gtsmodel. return a.cache.GetByURI(uri) }, func(account *gtsmodel.Account) error { - return a.newAccountQ(account).Where("account.uri = ?", uri).Scan(ctx) + return a.newAccountQ(account).Where("? = ?", bun.Ident("account.uri"), uri).Scan(ctx) }, ) } @@ -80,7 +79,7 @@ func (a *accountDB) GetAccountByURL(ctx context.Context, url string) (*gtsmodel. return a.cache.GetByURL(url) }, func(account *gtsmodel.Account) error { - return a.newAccountQ(account).Where("account.url = ?", url).Scan(ctx) + return a.newAccountQ(account).Where("? = ?", bun.Ident("account.url"), url).Scan(ctx) }, ) } @@ -95,11 +94,11 @@ func (a *accountDB) GetAccountByUsernameDomain(ctx context.Context, username str q := a.newAccountQ(account) if domain != "" { - q = q.Where("account.username = ?", username) - q = q.Where("account.domain = ?", domain) + q = q.Where("? = ?", bun.Ident("account.username"), username) + q = q.Where("? = ?", bun.Ident("account.domain"), domain) } else { - q = q.Where("account.username = ?", strings.ToLower(username)) - q = q.Where("account.domain IS NULL") + q = q.Where("? = ?", bun.Ident("account.username"), strings.ToLower(username)) + q = q.Where("? IS NULL", bun.Ident("account.domain")) } return q.Scan(ctx) @@ -114,7 +113,7 @@ func (a *accountDB) GetAccountByPubkeyID(ctx context.Context, id string) (*gtsmo return a.cache.GetByPubkeyID(id) }, func(account *gtsmodel.Account) error { - return a.newAccountQ(account).Where("account.public_key_uri = ?", id).Scan(ctx) + return a.newAccountQ(account).Where("? = ?", bun.Ident("account.public_key_uri"), id).Scan(ctx) }, ) } @@ -170,8 +169,8 @@ func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account // create links between this account and any emojis it uses // first clear out any old emoji links if _, err := tx.NewDelete(). - Model(&[]*gtsmodel.AccountToEmoji{}). - Where("account_id = ?", account.ID). + TableExpr("? AS ?", bun.Ident("account_to_emojis"), bun.Ident("account_to_emoji")). + Where("? = ?", bun.Ident("account_to_emoji.account_id"), account.ID). Exec(ctx); err != nil { return err } @@ -197,6 +196,32 @@ func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account return account, nil } +func (a *accountDB) DeleteAccount(ctx context.Context, id string) db.Error { + if err := a.conn.RunInTx(ctx, func(tx bun.Tx) error { + // clear out any emoji links + if _, err := tx. + NewDelete(). + TableExpr("? AS ?", bun.Ident("account_to_emojis"), bun.Ident("account_to_emoji")). + Where("? = ?", bun.Ident("account_to_emoji.account_id"), id). + Exec(ctx); err != nil { + return err + } + + // delete the account + _, err := tx. + NewUpdate(). + Model(>smodel.Account{ID: id}). + WherePK(). + Exec(ctx) + return err + }); err != nil { + return a.conn.ProcessError(err) + } + + a.cache.Invalidate(id) + return nil +} + func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gtsmodel.Account, db.Error) { account := new(gtsmodel.Account) @@ -204,11 +229,11 @@ func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gts if domain != "" { q = q. - Where("account.username = ?", domain). - Where("account.domain = ?", domain) + Where("? = ?", bun.Ident("account.username"), domain). + Where("? = ?", bun.Ident("account.domain"), domain) } else { q = q. - Where("account.username = ?", config.GetHost()). + Where("? = ?", bun.Ident("account.username"), config.GetHost()). WhereGroup(" AND ", whereEmptyOrNull("domain")) } @@ -224,10 +249,10 @@ func (a *accountDB) GetAccountLastPosted(ctx context.Context, accountID string) q := a.conn. NewSelect(). Model(status). - Order("id DESC"). - Limit(1). - Where("account_id = ?", accountID). - Column("created_at") + Column("status.created_at"). + Where("? = ?", bun.Ident("status.account_id"), accountID). + Order("status.id DESC"). + Limit(1) if err := q.Scan(ctx); err != nil { return time.Time{}, a.conn.ProcessError(err) @@ -240,12 +265,12 @@ func (a *accountDB) SetAccountHeaderOrAvatar(ctx context.Context, mediaAttachmen return errors.New("one media attachment cannot be both header and avatar") } - var headerOrAVI string + var column bun.Ident switch { case *mediaAttachment.Avatar: - headerOrAVI = "avatar" + column = bun.Ident("account.avatar_media_attachment_id") case *mediaAttachment.Header: - headerOrAVI = "header" + column = bun.Ident("account.header_media_attachment_id") default: return errors.New("given media attachment was neither a header nor an avatar") } @@ -257,11 +282,12 @@ func (a *accountDB) SetAccountHeaderOrAvatar(ctx context.Context, mediaAttachmen Exec(ctx); err != nil { return a.conn.ProcessError(err) } + if _, err := a.conn. NewUpdate(). - Model(>smodel.Account{}). - Set(fmt.Sprintf("%s_media_attachment_id = ?", headerOrAVI), mediaAttachment.ID). - Where("id = ?", accountID). + TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")). + Set("? = ?", column, mediaAttachment.ID). + Where("? = ?", bun.Ident("account.id"), accountID). Exec(ctx); err != nil { return a.conn.ProcessError(err) } @@ -284,7 +310,7 @@ func (a *accountDB) GetAccountFaves(ctx context.Context, accountID string) ([]*g if err := a.conn. NewSelect(). Model(faves). - Where("account_id = ?", accountID). + Where("? = ?", bun.Ident("status_fave.account_id"), accountID). Scan(ctx); err != nil { return nil, a.conn.ProcessError(err) } @@ -295,8 +321,8 @@ func (a *accountDB) GetAccountFaves(ctx context.Context, accountID string) ([]*g func (a *accountDB) CountAccountStatuses(ctx context.Context, accountID string) (int, db.Error) { return a.conn. NewSelect(). - Model(>smodel.Status{}). - Where("account_id = ?", accountID). + TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). + Where("? = ?", bun.Ident("status.account_id"), accountID). Count(ctx) } @@ -305,12 +331,12 @@ func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, li q := a.conn. NewSelect(). - Table("statuses"). - Column("id"). - Order("id DESC") + Model(>smodel.Status{}). + Column("status.id"). + Order("status.id DESC") if accountID != "" { - q = q.Where("account_id = ?", accountID) + q = q.Where("? = ?", bun.Ident("status.account_id"), accountID) } if limit != 0 { @@ -321,27 +347,27 @@ func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, li // include self-replies (threads) whereGroup := func(*bun.SelectQuery) *bun.SelectQuery { return q. - WhereOr("in_reply_to_account_id = ?", accountID). - WhereGroup(" OR ", whereEmptyOrNull("in_reply_to_uri")) + WhereOr("? = ?", bun.Ident("status.in_reply_to_account_id"), accountID). + WhereGroup(" OR ", whereEmptyOrNull("status.in_reply_to_uri")) } q = q.WhereGroup(" AND ", whereGroup) } if excludeReblogs { - q = q.WhereGroup(" AND ", whereEmptyOrNull("boost_of_id")) + q = q.WhereGroup(" AND ", whereEmptyOrNull("status.boost_of_id")) } if maxID != "" { - q = q.Where("id < ?", maxID) + q = q.Where("? < ?", bun.Ident("status.id"), maxID) } if minID != "" { - q = q.Where("id > ?", minID) + q = q.Where("? > ?", bun.Ident("status.id"), minID) } if pinnedOnly { - q = q.Where("pinned = ?", true) + q = q.Where("? = ?", bun.Ident("status.pinned"), true) } if mediaOnly { @@ -352,15 +378,15 @@ func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, li switch a.conn.Dialect().Name() { case dialect.PG: return q. - Where("? IS NOT NULL", bun.Ident("attachments")). - Where("? != '{}'", bun.Ident("attachments")) + Where("? IS NOT NULL", bun.Ident("status.attachments")). + Where("? != '{}'", bun.Ident("status.attachments")) case dialect.SQLite: return q. - Where("? IS NOT NULL", bun.Ident("attachments")). - Where("? != ''", bun.Ident("attachments")). - Where("? != 'null'", bun.Ident("attachments")). - Where("? != '{}'", bun.Ident("attachments")). - Where("? != '[]'", bun.Ident("attachments")) + Where("? IS NOT NULL", bun.Ident("status.attachments")). + Where("? != ''", bun.Ident("status.attachments")). + Where("? != 'null'", bun.Ident("status.attachments")). + Where("? != '{}'", bun.Ident("status.attachments")). + Where("? != '[]'", bun.Ident("status.attachments")) default: log.Panic("db dialect was neither pg nor sqlite") return q @@ -369,7 +395,7 @@ func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, li } if publicOnly { - q = q.Where("visibility = ?", gtsmodel.VisibilityPublic) + q = q.Where("? = ?", bun.Ident("status.visibility"), gtsmodel.VisibilityPublic) } if err := q.Scan(ctx, &statusIDs); err != nil { @@ -384,19 +410,19 @@ func (a *accountDB) GetAccountWebStatuses(ctx context.Context, accountID string, q := a.conn. NewSelect(). - Table("statuses"). - Column("id"). - Where("account_id = ?", accountID). - WhereGroup(" AND ", whereEmptyOrNull("in_reply_to_uri")). - WhereGroup(" AND ", whereEmptyOrNull("boost_of_id")). - Where("visibility = ?", gtsmodel.VisibilityPublic). - Where("federated = ?", true) + Model(>smodel.Status{}). + Column("status.id"). + Where("? = ?", bun.Ident("status.account_id"), accountID). + WhereGroup(" AND ", whereEmptyOrNull("status.in_reply_to_uri")). + WhereGroup(" AND ", whereEmptyOrNull("status.boost_of_id")). + Where("? = ?", bun.Ident("status.visibility"), gtsmodel.VisibilityPublic). + Where("? = ?", bun.Ident("status.federated"), true) if maxID != "" { - q = q.Where("id < ?", maxID) + q = q.Where("? < ?", bun.Ident("status.id"), maxID) } - q = q.Limit(limit).Order("id DESC") + q = q.Limit(limit).Order("status.id DESC") if err := q.Scan(ctx, &statusIDs); err != nil { return nil, a.conn.ProcessError(err) @@ -411,16 +437,16 @@ func (a *accountDB) GetAccountBlocks(ctx context.Context, accountID string, maxI fq := a.conn. NewSelect(). Model(&blocks). - Where("block.account_id = ?", accountID). + Where("? = ?", bun.Ident("block.account_id"), accountID). Relation("TargetAccount"). Order("block.id DESC") if maxID != "" { - fq = fq.Where("block.id < ?", maxID) + fq = fq.Where("? < ?", bun.Ident("block.id"), maxID) } if sinceID != "" { - fq = fq.Where("block.id > ?", sinceID) + fq = fq.Where("? > ?", bun.Ident("block.id"), sinceID) } if limit > 0 { diff --git a/internal/db/bundb/admin.go b/internal/db/bundb/admin.go index 9fa78eca0..47551ec08 100644 --- a/internal/db/bundb/admin.go +++ b/internal/db/bundb/admin.go @@ -22,7 +22,6 @@ import ( "context" "crypto/rand" "crypto/rsa" - "database/sql" "fmt" "net" "net/mail" @@ -37,21 +36,26 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/id" "github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/uris" + "github.com/uptrace/bun" "golang.org/x/crypto/bcrypt" ) +// generate RSA keys of this length +const rsaKeyBits = 2048 + type adminDB struct { - conn *DBConn - userCache *cache.UserCache + conn *DBConn + userCache *cache.UserCache + accountCache *cache.AccountCache } func (a *adminDB) IsUsernameAvailable(ctx context.Context, username string) (bool, db.Error) { q := a.conn. NewSelect(). - Model(>smodel.Account{}). - Where("username = ?", username). - Where("domain = ?", nil) - + TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")). + Column("account.id"). + Where("? = ?", bun.Ident("account.username"), username). + Where("? IS NULL", bun.Ident("account.domain")) return a.conn.NotExists(ctx, q) } @@ -64,29 +68,31 @@ func (a *adminDB) IsEmailAvailable(ctx context.Context, email string) (bool, db. domain := strings.Split(m.Address, "@")[1] // domain will always be the second part after @ // check if the email domain is blocked - if err := a.conn. + emailDomainBlockedQ := a.conn. NewSelect(). - Model(>smodel.EmailDomainBlock{}). - Where("domain = ?", domain). - Scan(ctx); err == nil { - // fail because we found something + TableExpr("? AS ?", bun.Ident("email_domain_blocks"), bun.Ident("email_domain_block")). + Column("email_domain_block.id"). + Where("? = ?", bun.Ident("email_domain_block.domain"), domain) + emailDomainBlocked, err := a.conn.Exists(ctx, emailDomainBlockedQ) + if err != nil { + return false, err + } + if emailDomainBlocked { return false, fmt.Errorf("email domain %s is blocked", domain) - } else if err != sql.ErrNoRows { - return false, a.conn.ProcessError(err) } // check if this email is associated with a user already q := a.conn. NewSelect(). - Model(>smodel.User{}). - Where("email = ?", email). - WhereOr("unconfirmed_email = ?", email) - + TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")). + Column("user.id"). + Where("? = ?", bun.Ident("user.email"), email). + WhereOr("? = ?", bun.Ident("user.unconfirmed_email"), email) return a.conn.NotExists(ctx, q) } func (a *adminDB) NewSignup(ctx context.Context, username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, db.Error) { - key, err := rsa.GenerateKey(rand.Reader, 2048) + key, err := rsa.GenerateKey(rand.Reader, rsaKeyBits) if err != nil { log.Errorf("error creating new rsa key: %s", err) return nil, err @@ -94,13 +100,20 @@ func (a *adminDB) NewSignup(ctx context.Context, username string, reason string, // if something went wrong while creating a user, we might already have an account, so check here first... acct := >smodel.Account{} - q := a.conn.NewSelect(). + if err := a.conn. + NewSelect(). Model(acct). - Where("username = ?", username). - WhereGroup(" AND ", whereEmptyOrNull("domain")) + Where("? = ?", bun.Ident("account.username"), username). + WhereGroup(" AND ", whereEmptyOrNull("account.domain")). + Scan(ctx); err != nil { + err = a.conn.ProcessError(err) + if err != db.ErrNoEntries { + log.Errorf("error checking for existing account: %s", err) + return nil, err + } - if err := q.Scan(ctx); err != nil { - // we just don't have an account yet so create one before we proceed + // if we have db.ErrNoEntries, we just don't have an + // account yet so create one before we proceed accountURIs := uris.GenerateURIsForAccount(username) accountID, err := id.NewRandomULID() if err != nil { @@ -126,14 +139,19 @@ func (a *adminDB) NewSignup(ctx context.Context, username string, reason string, FeaturedCollectionURI: accountURIs.CollectionURI, } + // insert the new account! if _, err = a.conn. NewInsert(). Model(acct). Exec(ctx); err != nil { return nil, a.conn.ProcessError(err) } + a.accountCache.Put(acct) } + // we either created or already had an account by now, + // so proceed with creating a user for that account + pw, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) if err != nil { return nil, fmt.Errorf("error hashing password: %s", err) @@ -171,6 +189,7 @@ func (a *adminDB) NewSignup(ctx context.Context, username string, reason string, u.Moderator = &moderator } + // insert the user! if _, err = a.conn. NewInsert(). Model(u). @@ -187,9 +206,10 @@ func (a *adminDB) CreateInstanceAccount(ctx context.Context) db.Error { q := a.conn. NewSelect(). - Model(>smodel.Account{}). - Where("username = ?", username). - WhereGroup(" AND ", whereEmptyOrNull("domain")) + TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")). + Column("account.id"). + Where("? = ?", bun.Ident("account.username"), username). + WhereGroup(" AND ", whereEmptyOrNull("account.domain")) exists, err := a.conn.Exists(ctx, q) if err != nil { @@ -200,7 +220,7 @@ func (a *adminDB) CreateInstanceAccount(ctx context.Context) db.Error { return nil } - key, err := rsa.GenerateKey(rand.Reader, 2048) + key, err := rsa.GenerateKey(rand.Reader, rsaKeyBits) if err != nil { log.Errorf("error creating new rsa key: %s", err) return err @@ -237,6 +257,7 @@ func (a *adminDB) CreateInstanceAccount(ctx context.Context) db.Error { return a.conn.ProcessError(err) } + a.accountCache.Put(acct) log.Infof("instance account %s CREATED with id %s", username, acct.ID) return nil } @@ -248,8 +269,9 @@ func (a *adminDB) CreateInstanceInstance(ctx context.Context) db.Error { // check if instance entry already exists q := a.conn. NewSelect(). - Model(>smodel.Instance{}). - Where("domain = ?", host) + Column("instance.id"). + TableExpr("? AS ?", bun.Ident("instances"), bun.Ident("instance")). + Where("? = ?", bun.Ident("instance.domain"), host) exists, err := a.conn.Exists(ctx, q) if err != nil { diff --git a/internal/db/bundb/admin_test.go b/internal/db/bundb/admin_test.go index 22041087a..60c450d23 100644 --- a/internal/db/bundb/admin_test.go +++ b/internal/db/bundb/admin_test.go @@ -23,6 +23,7 @@ import ( "testing" "github.com/stretchr/testify/suite" + gtsmodel "github.com/superseriousbusiness/gotosocial/internal/db/bundb/migrations/20211113114307_init" "github.com/superseriousbusiness/gotosocial/testrig" ) @@ -30,6 +31,44 @@ type AdminTestSuite struct { BunDBStandardTestSuite } +func (suite *AdminTestSuite) TestIsUsernameAvailableNo() { + available, err := suite.db.IsUsernameAvailable(context.Background(), "the_mighty_zork") + suite.NoError(err) + suite.False(available) +} + +func (suite *AdminTestSuite) TestIsUsernameAvailableYes() { + available, err := suite.db.IsUsernameAvailable(context.Background(), "someone_completely_different") + suite.NoError(err) + suite.True(available) +} + +func (suite *AdminTestSuite) TestIsEmailAvailableNo() { + available, err := suite.db.IsEmailAvailable(context.Background(), "zork@example.org") + suite.NoError(err) + suite.False(available) +} + +func (suite *AdminTestSuite) TestIsEmailAvailableYes() { + available, err := suite.db.IsEmailAvailable(context.Background(), "someone@somewhere.com") + suite.NoError(err) + suite.True(available) +} + +func (suite *AdminTestSuite) TestIsEmailAvailableDomainBlocked() { + if err := suite.db.Put(context.Background(), >smodel.EmailDomainBlock{ + ID: "01GEEV2R2YC5GRSN96761YJE47", + Domain: "somewhere.com", + CreatedByAccountID: suite.testAccounts["admin_account"].ID, + }); err != nil { + suite.FailNow(err.Error()) + } + + available, err := suite.db.IsEmailAvailable(context.Background(), "someone@somewhere.com") + suite.EqualError(err, "email domain somewhere.com is blocked") + suite.False(available) +} + func (suite *AdminTestSuite) TestCreateInstanceAccount() { // we need to take an empty db for this... testrig.StandardDBTeardown(suite.db) diff --git a/internal/db/bundb/basic.go b/internal/db/bundb/basic.go index cd80c9330..722bb0e3c 100644 --- a/internal/db/bundb/basic.go +++ b/internal/db/bundb/basic.go @@ -110,7 +110,7 @@ func (b *basicDB) UpdateWhere(ctx context.Context, where []db.Where, key string, updateWhere(q, where) - q = q.Set("? = ?", bun.Safe(key), value) + q = q.Set("? = ?", bun.Ident(key), value) _, err := q.Exec(ctx) return b.conn.ProcessError(err) diff --git a/internal/db/bundb/bundb.go b/internal/db/bundb/bundb.go index 70a44d4c1..02522e6f7 100644 --- a/internal/db/bundb/bundb.go +++ b/internal/db/bundb/bundb.go @@ -159,17 +159,11 @@ func NewBunDBService(ctx context.Context) (db.DB, error) { return nil, fmt.Errorf("db migration error: %s", err) } - // Create DB structs that require ptrs to each other - accounts := &accountDB{conn: conn, cache: cache.NewAccountCache()} - status := &statusDB{conn: conn, cache: cache.NewStatusCache()} - emoji := &emojiDB{conn: conn, cache: cache.NewEmojiCache()} - timeline := &timelineDB{conn: conn} - - // Setup DB cross-referencing - accounts.status = status - status.accounts = accounts - timeline.status = status + // Prepare caches required by more than one struct + userCache := cache.NewUserCache() + accountCache := cache.NewAccountCache() + // Prepare other caches // Prepare mentions cache // TODO: move into internal/cache mentionCache := grufcache.New[string, *gtsmodel.Mention]() @@ -182,22 +176,30 @@ func NewBunDBService(ctx context.Context) (db.DB, error) { notifCache.SetTTL(time.Minute*5, false) notifCache.Start(time.Second * 10) - // Prepare other caches - blockCache := cache.NewDomainBlockCache() - userCache := cache.NewUserCache() + // Create DB structs that require ptrs to each other + accounts := &accountDB{conn: conn, cache: accountCache} + status := &statusDB{conn: conn, cache: cache.NewStatusCache()} + emoji := &emojiDB{conn: conn, cache: cache.NewEmojiCache()} + timeline := &timelineDB{conn: conn} + + // Setup DB cross-referencing + accounts.status = status + status.accounts = accounts + timeline.status = status ps := &DBService{ Account: accounts, Admin: &adminDB{ - conn: conn, - userCache: userCache, + conn: conn, + userCache: userCache, + accountCache: accountCache, }, Basic: &basicDB{ conn: conn, }, Domain: &domainDB{ conn: conn, - cache: blockCache, + cache: cache.NewDomainBlockCache(), }, Emoji: emoji, Instance: &instanceDB{ diff --git a/internal/db/bundb/bundb_test.go b/internal/db/bundb/bundb_test.go index 581573056..2af6cf122 100644 --- a/internal/db/bundb/bundb_test.go +++ b/internal/db/bundb/bundb_test.go @@ -40,6 +40,7 @@ type BunDBStandardTestSuite struct { testStatuses map[string]*gtsmodel.Status testTags map[string]*gtsmodel.Tag testMentions map[string]*gtsmodel.Mention + testFollows map[string]*gtsmodel.Follow } func (suite *BunDBStandardTestSuite) SetupSuite() { @@ -52,6 +53,7 @@ func (suite *BunDBStandardTestSuite) SetupSuite() { suite.testStatuses = testrig.NewTestStatuses() suite.testTags = testrig.NewTestTags() suite.testMentions = testrig.NewTestMentions() + suite.testFollows = testrig.NewTestFollows() } func (suite *BunDBStandardTestSuite) SetupTest() { diff --git a/internal/db/bundb/domain.go b/internal/db/bundb/domain.go index 5d262c676..bcfdc4d10 100644 --- a/internal/db/bundb/domain.go +++ b/internal/db/bundb/domain.go @@ -28,6 +28,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/uptrace/bun" "golang.org/x/net/idna" ) @@ -95,7 +96,7 @@ func (d *domainDB) GetDomainBlock(ctx context.Context, domain string) (*gtsmodel q := d.conn. NewSelect(). Model(block). - Where("domain = ?", domain). + Where("? = ?", bun.Ident("domain_block.domain"), domain). Limit(1) // Query database for domain block @@ -126,7 +127,7 @@ func (d *domainDB) DeleteDomainBlock(ctx context.Context, domain string) db.Erro // Attempt to delete domain block if _, err := d.conn.NewDelete(). Model((*gtsmodel.DomainBlock)(nil)). - Where("domain = ?", domain). + Where("? = ?", bun.Ident("domain_block.domain"), domain). Exec(ctx); err != nil { return d.conn.ProcessError(err) } diff --git a/internal/db/bundb/emoji.go b/internal/db/bundb/emoji.go index 758da0feb..e781e2f00 100644 --- a/internal/db/bundb/emoji.go +++ b/internal/db/bundb/emoji.go @@ -54,12 +54,12 @@ func (e *emojiDB) GetCustomEmojis(ctx context.Context) ([]*gtsmodel.Emoji, db.Er q := e.conn. NewSelect(). - Table("emojis"). - Column("id"). - Where("visible_in_picker = true"). - Where("disabled = false"). - Where("domain IS NULL"). - Order("shortcode ASC") + TableExpr("? AS ?", bun.Ident("emojis"), bun.Ident("emoji")). + Column("emoji.id"). + Where("? = ?", bun.Ident("emoji.visible_in_picker"), true). + Where("? = ?", bun.Ident("emoji.disabled"), false). + Where("? IS NULL", bun.Ident("emoji.domain")). + Order("emoji.shortcode ASC") if err := q.Scan(ctx, &emojiIDs); err != nil { return nil, e.conn.ProcessError(err) @@ -75,7 +75,7 @@ func (e *emojiDB) GetEmojiByID(ctx context.Context, id string) (*gtsmodel.Emoji, return e.cache.GetByID(id) }, func(emoji *gtsmodel.Emoji) error { - return e.newEmojiQ(emoji).Where("emoji.id = ?", id).Scan(ctx) + return e.newEmojiQ(emoji).Where("? = ?", bun.Ident("emoji.id"), id).Scan(ctx) }, ) } @@ -87,7 +87,7 @@ func (e *emojiDB) GetEmojiByURI(ctx context.Context, uri string) (*gtsmodel.Emoj return e.cache.GetByURI(uri) }, func(emoji *gtsmodel.Emoji) error { - return e.newEmojiQ(emoji).Where("emoji.uri = ?", uri).Scan(ctx) + return e.newEmojiQ(emoji).Where("? = ?", bun.Ident("emoji.uri"), uri).Scan(ctx) }, ) } @@ -102,11 +102,11 @@ func (e *emojiDB) GetEmojiByShortcodeDomain(ctx context.Context, shortcode strin q := e.newEmojiQ(emoji) if domain != "" { - q = q.Where("emoji.shortcode = ?", shortcode) - q = q.Where("emoji.domain = ?", domain) + q = q.Where("? = ?", bun.Ident("emoji.shortcode"), shortcode) + q = q.Where("? = ?", bun.Ident("emoji.domain"), domain) } else { - q = q.Where("emoji.shortcode = ?", strings.ToLower(shortcode)) - q = q.Where("emoji.domain IS NULL") + q = q.Where("? = ?", bun.Ident("emoji.shortcode"), strings.ToLower(shortcode)) + q = q.Where("? IS NULL", bun.Ident("emoji.domain")) } return q.Scan(ctx) diff --git a/internal/db/bundb/instance.go b/internal/db/bundb/instance.go index fb6454e2f..604461708 100644 --- a/internal/db/bundb/instance.go +++ b/internal/db/bundb/instance.go @@ -24,7 +24,6 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" - "github.com/superseriousbusiness/gotosocial/internal/log" "github.com/uptrace/bun" ) @@ -35,15 +34,16 @@ type instanceDB struct { func (i *instanceDB) CountInstanceUsers(ctx context.Context, domain string) (int, db.Error) { q := i.conn. NewSelect(). - Model(&[]*gtsmodel.Account{}). - Where("username != ?", domain). - Where("? IS NULL", bun.Ident("suspended_at")) + TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")). + Column("account.id"). + Where("? != ?", bun.Ident("account.username"), domain). + Where("? IS NULL", bun.Ident("account.suspended_at")) - if domain == config.GetHost() { + if domain == config.GetHost() || domain == config.GetAccountDomain() { // if the domain is *this* domain, just count where the domain field is null - q = q.WhereGroup(" AND ", whereEmptyOrNull("domain")) + q = q.WhereGroup(" AND ", whereEmptyOrNull("account.domain")) } else { - q = q.Where("domain = ?", domain) + q = q.Where("? = ?", bun.Ident("account.domain"), domain) } count, err := q.Count(ctx) @@ -56,15 +56,16 @@ func (i *instanceDB) CountInstanceUsers(ctx context.Context, domain string) (int func (i *instanceDB) CountInstanceStatuses(ctx context.Context, domain string) (int, db.Error) { q := i.conn. NewSelect(). - Model(&[]*gtsmodel.Status{}) + TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")) - if domain == config.GetHost() { + if domain == config.GetHost() || domain == config.GetAccountDomain() { // if the domain is *this* domain, just count where local is true - q = q.Where("local = ?", true) + q = q.Where("? = ?", bun.Ident("status.local"), true) } else { // join on the domain of the account - q = q.Join("JOIN accounts AS account ON account.id = status.account_id"). - Where("account.domain = ?", domain) + q = q. + Join("JOIN ? AS ? ON ? = ?", bun.Ident("accounts"), bun.Ident("account"), bun.Ident("account.id"), bun.Ident("status.account_id")). + Where("? = ?", bun.Ident("account.domain"), domain) } count, err := q.Count(ctx) @@ -77,14 +78,14 @@ func (i *instanceDB) CountInstanceStatuses(ctx context.Context, domain string) ( func (i *instanceDB) CountInstanceDomains(ctx context.Context, domain string) (int, db.Error) { q := i.conn. NewSelect(). - Model(&[]*gtsmodel.Instance{}) + TableExpr("? AS ?", bun.Ident("instances"), bun.Ident("instance")) if domain == config.GetHost() { // if the domain is *this* domain, just count other instances it knows about // exclude domains that are blocked q = q. - Where("domain != ?", domain). - Where("? IS NULL", bun.Ident("suspended_at")) + Where("? != ?", bun.Ident("instance.domain"), domain). + Where("? IS NULL", bun.Ident("instance.suspended_at")) } else { // TODO: implement federated domain counting properly for remote domains return 0, nil @@ -103,10 +104,10 @@ func (i *instanceDB) GetInstancePeers(ctx context.Context, includeSuspended bool q := i.conn. NewSelect(). Model(&instances). - Where("domain != ?", config.GetHost()) + Where("? != ?", bun.Ident("instance.domain"), config.GetHost()) if !includeSuspended { - q = q.Where("? IS NULL", bun.Ident("suspended_at")) + q = q.Where("? IS NULL", bun.Ident("instance.suspended_at")) } if err := q.Scan(ctx); err != nil { @@ -117,17 +118,15 @@ func (i *instanceDB) GetInstancePeers(ctx context.Context, includeSuspended bool } func (i *instanceDB) GetInstanceAccounts(ctx context.Context, domain string, maxID string, limit int) ([]*gtsmodel.Account, db.Error) { - log.Debug("GetAccountsForInstance") - accounts := []*gtsmodel.Account{} q := i.conn.NewSelect(). Model(&accounts). - Where("domain = ?", domain). - Order("id DESC") + Where("? = ?", bun.Ident("account.domain"), domain). + Order("account.id DESC") if maxID != "" { - q = q.Where("id < ?", maxID) + q = q.Where("? < ?", bun.Ident("account.id"), maxID) } if limit > 0 { diff --git a/internal/db/bundb/instance_test.go b/internal/db/bundb/instance_test.go new file mode 100644 index 000000000..50d118888 --- /dev/null +++ b/internal/db/bundb/instance_test.go @@ -0,0 +1,83 @@ +/* + GoToSocial + Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see . +*/ + +package bundb_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/suite" + "github.com/superseriousbusiness/gotosocial/internal/config" +) + +type InstanceTestSuite struct { + BunDBStandardTestSuite +} + +func (suite *InstanceTestSuite) TestCountInstanceUsers() { + count, err := suite.db.CountInstanceUsers(context.Background(), config.GetHost()) + suite.NoError(err) + suite.Equal(4, count) +} + +func (suite *InstanceTestSuite) TestCountInstanceUsersRemote() { + count, err := suite.db.CountInstanceUsers(context.Background(), "fossbros-anonymous.io") + suite.NoError(err) + suite.Equal(1, count) +} + +func (suite *InstanceTestSuite) TestCountInstanceStatuses() { + count, err := suite.db.CountInstanceStatuses(context.Background(), config.GetHost()) + suite.NoError(err) + suite.Equal(16, count) +} + +func (suite *InstanceTestSuite) TestCountInstanceStatusesRemote() { + count, err := suite.db.CountInstanceStatuses(context.Background(), "fossbros-anonymous.io") + suite.NoError(err) + suite.Equal(1, count) +} + +func (suite *InstanceTestSuite) TestCountInstanceDomains() { + count, err := suite.db.CountInstanceDomains(context.Background(), config.GetHost()) + suite.NoError(err) + suite.Equal(2, count) +} + +func (suite *InstanceTestSuite) TestGetInstancePeers() { + peers, err := suite.db.GetInstancePeers(context.Background(), false) + suite.NoError(err) + suite.Len(peers, 2) +} + +func (suite *InstanceTestSuite) TestGetInstancePeersIncludeSuspended() { + peers, err := suite.db.GetInstancePeers(context.Background(), true) + suite.NoError(err) + suite.Len(peers, 2) +} + +func (suite *InstanceTestSuite) TestGetInstanceAccounts() { + accounts, err := suite.db.GetInstanceAccounts(context.Background(), "fossbros-anonymous.io", "", 10) + suite.NoError(err) + suite.Len(accounts, 1) +} + +func TestInstanceTestSuite(t *testing.T) { + suite.Run(t, new(InstanceTestSuite)) +} diff --git a/internal/db/bundb/media.go b/internal/db/bundb/media.go index 71433b901..39e0ad0e3 100644 --- a/internal/db/bundb/media.go +++ b/internal/db/bundb/media.go @@ -42,7 +42,7 @@ func (m *mediaDB) GetAttachmentByID(ctx context.Context, id string) (*gtsmodel.M attachment := >smodel.MediaAttachment{} q := m.newMediaQ(attachment). - Where("media_attachment.id = ?", id) + Where("? = ?", bun.Ident("media_attachment.id"), id) if err := q.Scan(ctx); err != nil { return nil, m.conn.ProcessError(err) @@ -56,10 +56,10 @@ func (m *mediaDB) GetRemoteOlderThan(ctx context.Context, olderThan time.Time, l q := m.conn. NewSelect(). Model(&attachments). - Where("media_attachment.cached = true"). - Where("media_attachment.avatar = false"). - Where("media_attachment.header = false"). - Where("media_attachment.created_at < ?", olderThan). + Where("? = ?", bun.Ident("media_attachment.cached"), true). + Where("? = ?", bun.Ident("media_attachment.avatar"), false). + Where("? = ?", bun.Ident("media_attachment.header"), false). + Where("? < ?", bun.Ident("media_attachment.created_at"), olderThan). WhereGroup(" AND ", whereNotEmptyAndNotNull("media_attachment.remote_url")). Order("media_attachment.created_at DESC") @@ -79,13 +79,13 @@ func (m *mediaDB) GetAvatarsAndHeaders(ctx context.Context, maxID string, limit q := m.newMediaQ(&attachments). WhereGroup(" AND ", func(innerQ *bun.SelectQuery) *bun.SelectQuery { return innerQ. - WhereOr("media_attachment.avatar = true"). - WhereOr("media_attachment.header = true") + WhereOr("? = ?", bun.Ident("media_attachment.avatar"), true). + WhereOr("? = ?", bun.Ident("media_attachment.header"), true) }). Order("media_attachment.id DESC") if maxID != "" { - q = q.Where("media_attachment.id < ?", maxID) + q = q.Where("? < ?", bun.Ident("media_attachment.id"), maxID) } if limit != 0 { @@ -103,15 +103,15 @@ func (m *mediaDB) GetLocalUnattachedOlderThan(ctx context.Context, olderThan tim attachments := []*gtsmodel.MediaAttachment{} q := m.newMediaQ(&attachments). - Where("media_attachment.cached = true"). - Where("media_attachment.avatar = false"). - Where("media_attachment.header = false"). - Where("media_attachment.created_at < ?", olderThan). - Where("media_attachment.remote_url IS NULL"). - Where("media_attachment.status_id IS NULL") + Where("? = ?", bun.Ident("media_attachment.cached"), true). + Where("? = ?", bun.Ident("media_attachment.avatar"), false). + Where("? = ?", bun.Ident("media_attachment.header"), false). + Where("? < ?", bun.Ident("media_attachment.created_at"), olderThan). + Where("? IS NULL", bun.Ident("media_attachment.remote_url")). + Where("? IS NULL", bun.Ident("media_attachment.status_id")) if maxID != "" { - q = q.Where("media_attachment.id < ?", maxID) + q = q.Where("? < ?", bun.Ident("media_attachment.id"), maxID) } if limit != 0 { diff --git a/internal/db/bundb/mention.go b/internal/db/bundb/mention.go index e2c83ef3f..355078021 100644 --- a/internal/db/bundb/mention.go +++ b/internal/db/bundb/mention.go @@ -46,7 +46,7 @@ func (m *mentionDB) getMentionDB(ctx context.Context, id string) (*gtsmodel.Ment mention := gtsmodel.Mention{} q := m.newMentionQ(&mention). - Where("mention.id = ?", id) + Where("? = ?", bun.Ident("mention.id"), id) if err := q.Scan(ctx); err != nil { return nil, m.conn.ProcessError(err) diff --git a/internal/db/bundb/notification.go b/internal/db/bundb/notification.go index 32523ca24..2c1b4848e 100644 --- a/internal/db/bundb/notification.go +++ b/internal/db/bundb/notification.go @@ -25,6 +25,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/log" + "github.com/uptrace/bun" ) type notificationDB struct { @@ -67,24 +68,24 @@ func (n *notificationDB) GetNotifications(ctx context.Context, accountID string, q := n.conn. NewSelect(). - Table("notifications"). - Column("id") + TableExpr("? AS ?", bun.Ident("notifications"), bun.Ident("notification")). + Column("notification.id") if maxID != "" { - q = q.Where("id < ?", maxID) + q = q.Where("? < ?", bun.Ident("notification.id"), maxID) } if sinceID != "" { - q = q.Where("id > ?", sinceID) + q = q.Where("? > ?", bun.Ident("notification.id"), sinceID) } for _, excludeType := range excludeTypes { - q = q.Where("notification_type != ?", excludeType) + q = q.Where("? != ?", bun.Ident("notification.notification_type"), excludeType) } q = q. - Where("target_account_id = ?", accountID). - Order("id DESC") + Where("? = ?", bun.Ident("notification.target_account_id"), accountID). + Order("notification.id DESC") if limit != 0 { q = q.Limit(limit) @@ -116,13 +117,12 @@ func (n *notificationDB) GetNotifications(ctx context.Context, accountID string, func (n *notificationDB) ClearNotifications(ctx context.Context, accountID string) db.Error { if _, err := n.conn. NewDelete(). - Table("notifications"). - Where("target_account_id = ?", accountID). + TableExpr("? AS ?", bun.Ident("notifications"), bun.Ident("notification")). + Where("? = ?", bun.Ident("notification.target_account_id"), accountID). Exec(ctx); err != nil { return n.conn.ProcessError(err) } n.cache.Clear() - return nil } diff --git a/internal/db/bundb/relationship.go b/internal/db/bundb/relationship.go index ba72a053a..470599b52 100644 --- a/internal/db/bundb/relationship.go +++ b/internal/db/bundb/relationship.go @@ -51,26 +51,25 @@ func (r *relationshipDB) newFollowQ(follow interface{}) *bun.SelectQuery { func (r *relationshipDB) IsBlocked(ctx context.Context, account1 string, account2 string, eitherDirection bool) (bool, db.Error) { q := r.conn. NewSelect(). - Model(>smodel.Block{}). - ExcludeColumn("id", "created_at", "updated_at", "uri"). - Limit(1) + TableExpr("? AS ?", bun.Ident("blocks"), bun.Ident("block")). + Column("block.id") if eitherDirection { q = q. WhereGroup(" OR ", func(inner *bun.SelectQuery) *bun.SelectQuery { return inner. - Where("account_id = ?", account1). - Where("target_account_id = ?", account2) + Where("? = ?", bun.Ident("block.account_id"), account1). + Where("? = ?", bun.Ident("block.target_account_id"), account2) }). WhereGroup(" OR ", func(inner *bun.SelectQuery) *bun.SelectQuery { return inner. - Where("account_id = ?", account2). - Where("target_account_id = ?", account1) + Where("? = ?", bun.Ident("block.account_id"), account2). + Where("? = ?", bun.Ident("block.target_account_id"), account1) }) } else { q = q. - Where("account_id = ?", account1). - Where("target_account_id = ?", account2) + Where("? = ?", bun.Ident("block.account_id"), account1). + Where("? = ?", bun.Ident("block.target_account_id"), account2) } return r.conn.Exists(ctx, q) @@ -80,8 +79,8 @@ func (r *relationshipDB) GetBlock(ctx context.Context, account1 string, account2 block := >smodel.Block{} q := r.newBlockQ(block). - Where("block.account_id = ?", account1). - Where("block.target_account_id = ?", account2) + Where("? = ?", bun.Ident("block.account_id"), account1). + Where("? = ?", bun.Ident("block.target_account_id"), account2) if err := q.Scan(ctx); err != nil { return nil, r.conn.ProcessError(err) @@ -99,13 +98,13 @@ func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount if err := r.conn. NewSelect(). Model(follow). - Where("account_id = ?", requestingAccount). - Where("target_account_id = ?", targetAccount). + Column("follow.show_reblogs", "follow.notify"). + Where("? = ?", bun.Ident("follow.account_id"), requestingAccount). + Where("? = ?", bun.Ident("follow.target_account_id"), targetAccount). Limit(1). Scan(ctx); err != nil { - if err != sql.ErrNoRows { - // a proper error - return nil, fmt.Errorf("getrelationship: error checking follow existence: %s", err) + if err := r.conn.ProcessError(err); err != db.ErrNoEntries { + return nil, fmt.Errorf("GetRelationship: error fetching follow: %s", err) } // no follow exists so these are all false rel.Following = false @@ -119,55 +118,56 @@ func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount } // check if the target account follows the requesting account - count, err := r.conn. + followedByQ := r.conn. NewSelect(). - Model(>smodel.Follow{}). - Where("account_id = ?", targetAccount). - Where("target_account_id = ?", requestingAccount). - Limit(1). - Count(ctx) + TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")). + Column("follow.id"). + Where("? = ?", bun.Ident("follow.account_id"), targetAccount). + Where("? = ?", bun.Ident("follow.target_account_id"), requestingAccount) + followedBy, err := r.conn.Exists(ctx, followedByQ) if err != nil { - return nil, fmt.Errorf("getrelationship: error checking followed_by existence: %s", err) + return nil, fmt.Errorf("GetRelationship: error checking followedBy: %s", err) } - rel.FollowedBy = count > 0 - - // check if the requesting account blocks the target account - count, err = r.conn.NewSelect(). - Model(>smodel.Block{}). - Where("account_id = ?", requestingAccount). - Where("target_account_id = ?", targetAccount). - Limit(1). - Count(ctx) - if err != nil { - return nil, fmt.Errorf("getrelationship: error checking blocking existence: %s", err) - } - rel.Blocking = count > 0 - - // check if the target account blocks the requesting account - count, err = r.conn. - NewSelect(). - Model(>smodel.Block{}). - Where("account_id = ?", targetAccount). - Where("target_account_id = ?", requestingAccount). - Limit(1). - Count(ctx) - if err != nil { - return nil, fmt.Errorf("getrelationship: error checking blocked existence: %s", err) - } - rel.BlockedBy = count > 0 + rel.FollowedBy = followedBy // check if there's a pending following request from requesting account to target account - count, err = r.conn. + requestedQ := r.conn. NewSelect(). - Model(>smodel.FollowRequest{}). - Where("account_id = ?", requestingAccount). - Where("target_account_id = ?", targetAccount). - Limit(1). - Count(ctx) + TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")). + Column("follow_request.id"). + Where("? = ?", bun.Ident("follow_request.account_id"), requestingAccount). + Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccount) + requested, err := r.conn.Exists(ctx, requestedQ) if err != nil { - return nil, fmt.Errorf("getrelationship: error checking blocked existence: %s", err) + return nil, fmt.Errorf("GetRelationship: error checking requested: %s", err) } - rel.Requested = count > 0 + rel.Requested = requested + + // check if the requesting account is blocking the target account + blockingQ := r.conn. + NewSelect(). + TableExpr("? AS ?", bun.Ident("blocks"), bun.Ident("block")). + Column("block.id"). + Where("? = ?", bun.Ident("block.account_id"), requestingAccount). + Where("? = ?", bun.Ident("block.target_account_id"), targetAccount) + blocking, err := r.conn.Exists(ctx, blockingQ) + if err != nil { + return nil, fmt.Errorf("GetRelationship: error checking blocking: %s", err) + } + rel.Blocking = blocking + + // check if the requesting account is blocked by the target account + blockedByQ := r.conn. + NewSelect(). + TableExpr("? AS ?", bun.Ident("blocks"), bun.Ident("block")). + Column("block.id"). + Where("? = ?", bun.Ident("block.account_id"), targetAccount). + Where("? = ?", bun.Ident("block.target_account_id"), requestingAccount) + blockedBy, err := r.conn.Exists(ctx, blockedByQ) + if err != nil { + return nil, fmt.Errorf("GetRelationship: error checking blockedBy: %s", err) + } + rel.BlockedBy = blockedBy return rel, nil } @@ -179,10 +179,10 @@ func (r *relationshipDB) IsFollowing(ctx context.Context, sourceAccount *gtsmode q := r.conn. NewSelect(). - Model(>smodel.Follow{}). - Where("account_id = ?", sourceAccount.ID). - Where("target_account_id = ?", targetAccount.ID). - Limit(1) + TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")). + Column("follow.id"). + Where("? = ?", bun.Ident("follow.account_id"), sourceAccount.ID). + Where("? = ?", bun.Ident("follow.target_account_id"), targetAccount.ID) return r.conn.Exists(ctx, q) } @@ -194,9 +194,10 @@ func (r *relationshipDB) IsFollowRequested(ctx context.Context, sourceAccount *g q := r.conn. NewSelect(). - Model(>smodel.FollowRequest{}). - Where("account_id = ?", sourceAccount.ID). - Where("target_account_id = ?", targetAccount.ID) + TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")). + Column("follow_request.id"). + Where("? = ?", bun.Ident("follow_request.account_id"), sourceAccount.ID). + Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccount.ID) return r.conn.Exists(ctx, q) } @@ -222,82 +223,98 @@ func (r *relationshipDB) IsMutualFollowing(ctx context.Context, account1 *gtsmod } func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (*gtsmodel.Follow, db.Error) { - // make sure the original follow request exists - fr := >smodel.FollowRequest{} - if err := r.conn. - NewSelect(). - Model(fr). - Where("account_id = ?", originAccountID). - Where("target_account_id = ?", targetAccountID). - Scan(ctx); err != nil { - return nil, r.conn.ProcessError(err) - } - - // create a new follow to 'replace' the request with - follow := >smodel.Follow{ - ID: fr.ID, - AccountID: originAccountID, - TargetAccountID: targetAccountID, - URI: fr.URI, - } - - // if the follow already exists, just update the URI -- we don't need to do anything else - if _, err := r.conn. - NewInsert(). - Model(follow). - On("CONFLICT (account_id,target_account_id) DO UPDATE set uri = ?", follow.URI). - Exec(ctx); err != nil { - return nil, r.conn.ProcessError(err) - } - - // now remove the follow request - if _, err := r.conn. - NewDelete(). - Model(>smodel.FollowRequest{}). - Where("account_id = ?", originAccountID). - Where("target_account_id = ?", targetAccountID). - Exec(ctx); err != nil { + var follow *gtsmodel.Follow + + if err := r.conn.RunInTx(ctx, func(tx bun.Tx) error { + // get original follow request + followRequest := >smodel.FollowRequest{} + if err := tx. + NewSelect(). + Model(followRequest). + Where("? = ?", bun.Ident("follow_request.account_id"), originAccountID). + Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccountID). + Scan(ctx); err != nil { + return err + } + + // create a new follow to 'replace' the request with + follow = >smodel.Follow{ + ID: followRequest.ID, + AccountID: originAccountID, + TargetAccountID: targetAccountID, + URI: followRequest.URI, + } + + // if the follow already exists, just update the URI -- we don't need to do anything else + if _, err := tx. + NewInsert(). + Model(follow). + On("CONFLICT (?,?) DO UPDATE set ? = ?", bun.Ident("account_id"), bun.Ident("target_account_id"), bun.Ident("uri"), follow.URI). + Exec(ctx); err != nil { + return err + } + + // now remove the follow request + if _, err := tx. + NewDelete(). + Model(followRequest). + WherePK(). + Exec(ctx); err != nil { + return err + } + + return nil + }); err != nil { return nil, r.conn.ProcessError(err) } + // return the new follow return follow, nil } func (r *relationshipDB) RejectFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (*gtsmodel.FollowRequest, db.Error) { - // first get the follow request out of the database - fr := >smodel.FollowRequest{} - if err := r.conn. - NewSelect(). - Model(fr). - Where("account_id = ?", originAccountID). - Where("target_account_id = ?", targetAccountID). - Scan(ctx); err != nil { - return nil, r.conn.ProcessError(err) - } + followRequest := >smodel.FollowRequest{} - // now delete it from the database by ID - if _, err := r.conn. - NewDelete(). - Model(>smodel.FollowRequest{ID: fr.ID}). - WherePK(). - Exec(ctx); err != nil { + if err := r.conn.RunInTx(ctx, func(tx bun.Tx) error { + // get original follow request + if err := tx. + NewSelect(). + Model(followRequest). + Where("? = ?", bun.Ident("follow_request.account_id"), originAccountID). + Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccountID). + Scan(ctx); err != nil { + return err + } + + // now delete it from the database by ID + if _, err := tx. + NewDelete(). + Model(followRequest). + WherePK(). + Exec(ctx); err != nil { + return err + } + + return nil + }); err != nil { return nil, r.conn.ProcessError(err) } // return the deleted follow request - return fr, nil + return followRequest, nil } func (r *relationshipDB) GetAccountFollowRequests(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, db.Error) { followRequests := []*gtsmodel.FollowRequest{} q := r.newFollowQ(&followRequests). - Where("target_account_id = ?", accountID). + Where("? = ?", bun.Ident("follow_request.target_account_id"), accountID). Order("follow_request.updated_at DESC") if err := q.Scan(ctx); err != nil { return nil, r.conn.ProcessError(err) } + return followRequests, nil } @@ -305,21 +322,31 @@ func (r *relationshipDB) GetAccountFollows(ctx context.Context, accountID string follows := []*gtsmodel.Follow{} q := r.newFollowQ(&follows). - Where("account_id = ?", accountID). + Where("? = ?", bun.Ident("follow.account_id"), accountID). Order("follow.updated_at DESC") if err := q.Scan(ctx); err != nil { return nil, r.conn.ProcessError(err) } + return follows, nil } func (r *relationshipDB) CountAccountFollows(ctx context.Context, accountID string, localOnly bool) (int, db.Error) { - return r.conn. + q := r.conn. NewSelect(). - Model(&[]*gtsmodel.Follow{}). - Where("account_id = ?", accountID). - Count(ctx) + TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")) + + if localOnly { + q = q. + Join("JOIN ? AS ? ON ? = ?", bun.Ident("accounts"), bun.Ident("account"), bun.Ident("follow.target_account_id"), bun.Ident("account.id")). + Where("? = ?", bun.Ident("follow.account_id"), accountID). + Where("? IS NULL", bun.Ident("account.domain")) + } else { + q = q.Where("? = ?", bun.Ident("follow.account_id"), accountID) + } + + return q.Count(ctx) } func (r *relationshipDB) GetAccountFollowedBy(ctx context.Context, accountID string, localOnly bool) ([]*gtsmodel.Follow, db.Error) { @@ -331,12 +358,12 @@ func (r *relationshipDB) GetAccountFollowedBy(ctx context.Context, accountID str Order("follow.updated_at DESC") if localOnly { - q = q.ColumnExpr("follow.*"). - Join("JOIN accounts AS a ON follow.account_id = CAST(a.id as TEXT)"). - Where("follow.target_account_id = ?", accountID). - WhereGroup(" AND ", whereEmptyOrNull("a.domain")) + q = q. + Join("JOIN ? AS ? ON ? = ?", bun.Ident("accounts"), bun.Ident("account"), bun.Ident("follow.account_id"), bun.Ident("account.id")). + Where("? = ?", bun.Ident("follow.target_account_id"), accountID). + Where("? IS NULL", bun.Ident("account.domain")) } else { - q = q.Where("target_account_id = ?", accountID) + q = q.Where("? = ?", bun.Ident("follow.target_account_id"), accountID) } err := q.Scan(ctx) @@ -347,9 +374,18 @@ func (r *relationshipDB) GetAccountFollowedBy(ctx context.Context, accountID str } func (r *relationshipDB) CountAccountFollowedBy(ctx context.Context, accountID string, localOnly bool) (int, db.Error) { - return r.conn. + q := r.conn. NewSelect(). - Model(&[]*gtsmodel.Follow{}). - Where("target_account_id = ?", accountID). - Count(ctx) + TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")) + + if localOnly { + q = q. + Join("JOIN ? AS ? ON ? = ?", bun.Ident("accounts"), bun.Ident("account"), bun.Ident("follow.account_id"), bun.Ident("account.id")). + Where("? = ?", bun.Ident("follow.target_account_id"), accountID). + Where("? IS NULL", bun.Ident("account.domain")) + } else { + q = q.Where("? = ?", bun.Ident("follow.target_account_id"), accountID) + } + + return q.Count(ctx) } diff --git a/internal/db/bundb/relationship_test.go b/internal/db/bundb/relationship_test.go index 34fe85a57..3df16e2f3 100644 --- a/internal/db/bundb/relationship_test.go +++ b/internal/db/bundb/relationship_test.go @@ -20,7 +20,6 @@ package bundb_test import ( "context" - "errors" "testing" "github.com/stretchr/testify/suite" @@ -48,12 +47,14 @@ func (suite *RelationshipTestSuite) TestIsBlocked() { suite.False(blocked) // have account1 block account2 - suite.db.Put(ctx, >smodel.Block{ + if err := suite.db.Put(ctx, >smodel.Block{ ID: "01G202BCSXXJZ70BHB5KCAHH8C", URI: "http://localhost:8080/some_block_uri_1", AccountID: account1, TargetAccountID: account2, - }) + }); err != nil { + suite.FailNow(err.Error()) + } // account 1 now blocks account 2 blocked, err = suite.db.IsBlocked(ctx, account1, account2, false) @@ -75,62 +76,242 @@ func (suite *RelationshipTestSuite) TestIsBlocked() { } func (suite *RelationshipTestSuite) TestGetBlock() { - suite.Suite.T().Skip("TODO: implement") + ctx := context.Background() + + account1 := suite.testAccounts["local_account_1"].ID + account2 := suite.testAccounts["local_account_2"].ID + + if err := suite.db.Put(ctx, >smodel.Block{ + ID: "01G202BCSXXJZ70BHB5KCAHH8C", + URI: "http://localhost:8080/some_block_uri_1", + AccountID: account1, + TargetAccountID: account2, + }); err != nil { + suite.FailNow(err.Error()) + } + + block, err := suite.db.GetBlock(ctx, account1, account2) + suite.NoError(err) + suite.NotNil(block) + suite.Equal("01G202BCSXXJZ70BHB5KCAHH8C", block.ID) } func (suite *RelationshipTestSuite) TestGetRelationship() { - suite.Suite.T().Skip("TODO: implement") + requestingAccount := suite.testAccounts["local_account_1"] + targetAccount := suite.testAccounts["admin_account"] + + relationship, err := suite.db.GetRelationship(context.Background(), requestingAccount.ID, targetAccount.ID) + suite.NoError(err) + suite.NotNil(relationship) + + suite.True(relationship.Following) + suite.True(relationship.ShowingReblogs) + suite.False(relationship.Notifying) + suite.True(relationship.FollowedBy) + suite.False(relationship.Blocking) + suite.False(relationship.BlockedBy) + suite.False(relationship.Muting) + suite.False(relationship.MutingNotifications) + suite.False(relationship.Requested) + suite.False(relationship.DomainBlocking) + suite.False(relationship.Endorsed) + suite.Empty(relationship.Note) } -func (suite *RelationshipTestSuite) TestIsFollowing() { - suite.Suite.T().Skip("TODO: implement") +func (suite *RelationshipTestSuite) TestIsFollowingYes() { + requestingAccount := suite.testAccounts["local_account_1"] + targetAccount := suite.testAccounts["admin_account"] + isFollowing, err := suite.db.IsFollowing(context.Background(), requestingAccount, targetAccount) + suite.NoError(err) + suite.True(isFollowing) +} + +func (suite *RelationshipTestSuite) TestIsFollowingNo() { + requestingAccount := suite.testAccounts["admin_account"] + targetAccount := suite.testAccounts["local_account_2"] + isFollowing, err := suite.db.IsFollowing(context.Background(), requestingAccount, targetAccount) + suite.NoError(err) + suite.False(isFollowing) } func (suite *RelationshipTestSuite) TestIsMutualFollowing() { - suite.Suite.T().Skip("TODO: implement") + requestingAccount := suite.testAccounts["local_account_1"] + targetAccount := suite.testAccounts["admin_account"] + isMutualFollowing, err := suite.db.IsMutualFollowing(context.Background(), requestingAccount, targetAccount) + suite.NoError(err) + suite.True(isMutualFollowing) } -func (suite *RelationshipTestSuite) AcceptFollowRequest() { - for _, account := range suite.testAccounts { - _, err := suite.db.AcceptFollowRequest(context.Background(), account.ID, "NON-EXISTENT-ID") - if err != nil && !errors.Is(err, db.ErrNoEntries) { - suite.Suite.Fail("error accepting follow request: %v", err) - } +func (suite *RelationshipTestSuite) TestIsMutualFollowingNo() { + requestingAccount := suite.testAccounts["local_account_1"] + targetAccount := suite.testAccounts["local_account_2"] + isMutualFollowing, err := suite.db.IsMutualFollowing(context.Background(), requestingAccount, targetAccount) + suite.NoError(err) + suite.True(isMutualFollowing) +} + +func (suite *RelationshipTestSuite) TestAcceptFollowRequestOK() { + ctx := context.Background() + account := suite.testAccounts["admin_account"] + targetAccount := suite.testAccounts["local_account_2"] + + followRequest := >smodel.FollowRequest{ + ID: "01GEF753FWHCHRDWR0QEHBXM8W", + URI: "http://localhost:8080/weeeeeeeeeeeeeeeee", + AccountID: account.ID, + TargetAccountID: targetAccount.ID, } -} -func (suite *RelationshipTestSuite) GetAccountFollowRequests() { - suite.Suite.T().Skip("TODO: implement") -} - -func (suite *RelationshipTestSuite) GetAccountFollows() { - suite.Suite.T().Skip("TODO: implement") -} - -func (suite *RelationshipTestSuite) CountAccountFollows() { - suite.Suite.T().Skip("TODO: implement") -} - -func (suite *RelationshipTestSuite) GetAccountFollowedBy() { - // TODO: more comprehensive tests here - - for _, account := range suite.testAccounts { - var err error - - _, err = suite.db.GetAccountFollowedBy(context.Background(), account.ID, false) - if err != nil { - suite.Suite.Fail("error checking accounts followed by: %v", err) - } - - _, err = suite.db.GetAccountFollowedBy(context.Background(), account.ID, true) - if err != nil { - suite.Suite.Fail("error checking localOnly accounts followed by: %v", err) - } + if err := suite.db.Put(ctx, followRequest); err != nil { + suite.FailNow(err.Error()) } + + follow, err := suite.db.AcceptFollowRequest(ctx, account.ID, targetAccount.ID) + suite.NoError(err) + suite.NotNil(follow) + suite.Equal(followRequest.URI, follow.URI) } -func (suite *RelationshipTestSuite) CountAccountFollowedBy() { - suite.Suite.T().Skip("TODO: implement") +func (suite *RelationshipTestSuite) TestAcceptFollowRequestNotExisting() { + ctx := context.Background() + account := suite.testAccounts["admin_account"] + targetAccount := suite.testAccounts["local_account_2"] + + follow, err := suite.db.AcceptFollowRequest(ctx, account.ID, targetAccount.ID) + suite.ErrorIs(err, db.ErrNoEntries) + suite.Nil(follow) +} + +func (suite *RelationshipTestSuite) TestAcceptFollowRequestFollowAlreadyExists() { + ctx := context.Background() + account := suite.testAccounts["local_account_1"] + targetAccount := suite.testAccounts["admin_account"] + + // follow already exists in the db from local_account_1 -> admin_account + existingFollow := >smodel.Follow{} + if err := suite.db.GetByID(ctx, suite.testFollows["local_account_1_admin_account"].ID, existingFollow); err != nil { + suite.FailNow(err.Error()) + } + + followRequest := >smodel.FollowRequest{ + ID: "01GEF753FWHCHRDWR0QEHBXM8W", + URI: "http://localhost:8080/weeeeeeeeeeeeeeeee", + AccountID: account.ID, + TargetAccountID: targetAccount.ID, + } + + if err := suite.db.Put(ctx, followRequest); err != nil { + suite.FailNow(err.Error()) + } + + follow, err := suite.db.AcceptFollowRequest(ctx, account.ID, targetAccount.ID) + suite.NoError(err) + suite.NotNil(follow) + + // uri should be equal to value of new/overlapping follow request + suite.NotEqual(followRequest.URI, existingFollow.URI) + suite.Equal(followRequest.URI, follow.URI) +} + +func (suite *RelationshipTestSuite) TestRejectFollowRequestOK() { + ctx := context.Background() + account := suite.testAccounts["admin_account"] + targetAccount := suite.testAccounts["local_account_2"] + + followRequest := >smodel.FollowRequest{ + ID: "01GEF753FWHCHRDWR0QEHBXM8W", + URI: "http://localhost:8080/weeeeeeeeeeeeeeeee", + AccountID: account.ID, + TargetAccountID: targetAccount.ID, + } + + if err := suite.db.Put(ctx, followRequest); err != nil { + suite.FailNow(err.Error()) + } + + rejectedFollowRequest, err := suite.db.RejectFollowRequest(ctx, account.ID, targetAccount.ID) + suite.NoError(err) + suite.NotNil(rejectedFollowRequest) +} + +func (suite *RelationshipTestSuite) TestRejectFollowRequestNotExisting() { + ctx := context.Background() + account := suite.testAccounts["admin_account"] + targetAccount := suite.testAccounts["local_account_2"] + + rejectedFollowRequest, err := suite.db.RejectFollowRequest(ctx, account.ID, targetAccount.ID) + suite.ErrorIs(err, db.ErrNoEntries) + suite.Nil(rejectedFollowRequest) +} + +func (suite *RelationshipTestSuite) TestGetAccountFollowRequests() { + ctx := context.Background() + account := suite.testAccounts["admin_account"] + targetAccount := suite.testAccounts["local_account_2"] + + followRequest := >smodel.FollowRequest{ + ID: "01GEF753FWHCHRDWR0QEHBXM8W", + URI: "http://localhost:8080/weeeeeeeeeeeeeeeee", + AccountID: account.ID, + TargetAccountID: targetAccount.ID, + } + + if err := suite.db.Put(ctx, followRequest); err != nil { + suite.FailNow(err.Error()) + } + + followRequests, err := suite.db.GetAccountFollowRequests(ctx, targetAccount.ID) + suite.NoError(err) + suite.Len(followRequests, 1) +} + +func (suite *RelationshipTestSuite) TestGetAccountFollows() { + account := suite.testAccounts["local_account_1"] + follows, err := suite.db.GetAccountFollows(context.Background(), account.ID) + suite.NoError(err) + suite.Len(follows, 2) +} + +func (suite *RelationshipTestSuite) TestCountAccountFollowsLocalOnly() { + account := suite.testAccounts["local_account_1"] + followsCount, err := suite.db.CountAccountFollows(context.Background(), account.ID, true) + suite.NoError(err) + suite.Equal(2, followsCount) +} + +func (suite *RelationshipTestSuite) TestCountAccountFollows() { + account := suite.testAccounts["local_account_1"] + followsCount, err := suite.db.CountAccountFollows(context.Background(), account.ID, false) + suite.NoError(err) + suite.Equal(2, followsCount) +} + +func (suite *RelationshipTestSuite) TestGetAccountFollowedBy() { + account := suite.testAccounts["local_account_1"] + follows, err := suite.db.GetAccountFollowedBy(context.Background(), account.ID, false) + suite.NoError(err) + suite.Len(follows, 2) +} + +func (suite *RelationshipTestSuite) TestGetAccountFollowedByLocalOnly() { + account := suite.testAccounts["local_account_1"] + follows, err := suite.db.GetAccountFollowedBy(context.Background(), account.ID, true) + suite.NoError(err) + suite.Len(follows, 2) +} + +func (suite *RelationshipTestSuite) TestCountAccountFollowedBy() { + account := suite.testAccounts["local_account_1"] + followsCount, err := suite.db.CountAccountFollowedBy(context.Background(), account.ID, false) + suite.NoError(err) + suite.Equal(2, followsCount) +} + +func (suite *RelationshipTestSuite) TestCountAccountFollowedByLocalOnly() { + account := suite.testAccounts["local_account_1"] + followsCount, err := suite.db.CountAccountFollowedBy(context.Background(), account.ID, true) + suite.NoError(err) + suite.Equal(2, followsCount) } func TestRelationshipTestSuite(t *testing.T) { diff --git a/internal/db/bundb/user.go b/internal/db/bundb/user.go index 46f24c4b2..9d2bac7a6 100644 --- a/internal/db/bundb/user.go +++ b/internal/db/bundb/user.go @@ -67,7 +67,7 @@ func (u *userDB) GetUserByID(ctx context.Context, id string) (*gtsmodel.User, db return u.cache.GetByID(id) }, func(user *gtsmodel.User) error { - return u.newUserQ(user).Where("user.id = ?", id).Scan(ctx) + return u.newUserQ(user).Where("? = ?", bun.Ident("user.id"), id).Scan(ctx) }, ) } @@ -79,7 +79,7 @@ func (u *userDB) GetUserByAccountID(ctx context.Context, accountID string) (*gts return u.cache.GetByAccountID(accountID) }, func(user *gtsmodel.User) error { - return u.newUserQ(user).Where("user.account_id = ?", accountID).Scan(ctx) + return u.newUserQ(user).Where("? = ?", bun.Ident("user.account_id"), accountID).Scan(ctx) }, ) } @@ -91,7 +91,7 @@ func (u *userDB) GetUserByEmailAddress(ctx context.Context, emailAddress string) return u.cache.GetByEmail(emailAddress) }, func(user *gtsmodel.User) error { - return u.newUserQ(user).Where("user.email = ?", emailAddress).Scan(ctx) + return u.newUserQ(user).Where("? = ?", bun.Ident("user.email"), emailAddress).Scan(ctx) }, ) } @@ -103,7 +103,7 @@ func (u *userDB) GetUserByConfirmationToken(ctx context.Context, confirmationTok return u.cache.GetByConfirmationToken(confirmationToken) }, func(user *gtsmodel.User) error { - return u.newUserQ(user).Where("user.confirmation_token = ?", confirmationToken).Scan(ctx) + return u.newUserQ(user).Where("? = ?", bun.Ident("user.confirmation_token"), confirmationToken).Scan(ctx) }, ) } diff --git a/internal/db/bundb/util.go b/internal/db/bundb/util.go index 434d12f32..34f7eb76f 100644 --- a/internal/db/bundb/util.go +++ b/internal/db/bundb/util.go @@ -85,14 +85,8 @@ func parseWhere(w db.Where) (query string, args []interface{}) { return } - if w.CaseInsensitive { - query = "LOWER(?) != LOWER(?)" - args = []interface{}{bun.Safe(w.Key), w.Value} - return - } - query = "? != ?" - args = []interface{}{bun.Safe(w.Key), w.Value} + args = []interface{}{bun.Ident(w.Key), w.Value} return } @@ -102,13 +96,7 @@ func parseWhere(w db.Where) (query string, args []interface{}) { return } - if w.CaseInsensitive { - query = "LOWER(?) = LOWER(?)" - args = []interface{}{bun.Safe(w.Key), w.Value} - return - } - query = "? = ?" - args = []interface{}{bun.Safe(w.Key), w.Value} + args = []interface{}{bun.Ident(w.Key), w.Value} return } diff --git a/internal/db/params.go b/internal/db/params.go index d1809f1c4..84694d6d3 100644 --- a/internal/db/params.go +++ b/internal/db/params.go @@ -24,9 +24,6 @@ type Where struct { Key string // The value to match. Value interface{} - // Whether the value (if a string) should be case sensitive or not. - // Defaults to false. - CaseInsensitive bool // If set, reverse the where. // `WHERE k = v` becomes `WHERE k != v`. // `WHERE k IS NULL` becomes `WHERE k IS NOT NULL` diff --git a/internal/processing/admin/createdomainblock.go b/internal/processing/admin/createdomainblock.go index fcc0cb480..a9b26a357 100644 --- a/internal/processing/admin/createdomainblock.go +++ b/internal/processing/admin/createdomainblock.go @@ -133,8 +133,10 @@ func (p *processor) initiateDomainBlockSideEffects(ctx context.Context, account } // if we have an instance account for this instance, delete it - if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "username", Value: block.Domain, CaseInsensitive: true}}, >smodel.Account{}); err != nil { - l.Errorf("domainBlockProcessSideEffects: db error removing instance account: %s", err) + if instanceAccount, err := p.db.GetAccountByUsernameDomain(ctx, block.Domain, block.Domain); err == nil { + if err := p.db.DeleteAccount(ctx, instanceAccount.ID); err != nil { + l.Errorf("domainBlockProcessSideEffects: db error deleting instance account: %s", err) + } } // delete accounts through the normal account deletion system (which should also delete media + posts + remove posts from timelines) diff --git a/internal/processing/admin/deletedomainblock.go b/internal/processing/admin/deletedomainblock.go index b65954fe5..722c76b52 100644 --- a/internal/processing/admin/deletedomainblock.go +++ b/internal/processing/admin/deletedomainblock.go @@ -55,7 +55,7 @@ func (p *processor) DomainBlockDelete(ctx context.Context, account *gtsmodel.Acc // remove the domain block reference from the instance, if we have an entry for it i := >smodel.Instance{} if err := p.db.GetWhere(ctx, []db.Where{ - {Key: "domain", Value: domainBlock.Domain, CaseInsensitive: true}, + {Key: "domain", Value: domainBlock.Domain}, {Key: "domain_block_id", Value: id}, }, i); err == nil { updatingColumns := []string{"suspended_at", "domain_block_id", "updated_at"}