[chore] move client/federator workerpools to Workers{} (#1575)

* replace concurrency worker pools with base models in State.Workers, update code and tests accordingly

* improve code comment

* change back testrig default log level

* un-comment-out TestAnnounceTwice() and fix

---------

Signed-off-by: kim <grufwub@gmail.com>
Reviewed-by: tobi
This commit is contained in:
kim 2023-03-01 18:26:53 +00:00 committed by GitHub
parent 24cec4e7aa
commit baf933cb9f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
130 changed files with 1037 additions and 1083 deletions

View file

@ -35,7 +35,6 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/middleware" "github.com/superseriousbusiness/gotosocial/internal/middleware"
"go.uber.org/automaxprocs/maxprocs" "go.uber.org/automaxprocs/maxprocs"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db/bundb" "github.com/superseriousbusiness/gotosocial/internal/db/bundb"
"github.com/superseriousbusiness/gotosocial/internal/email" "github.com/superseriousbusiness/gotosocial/internal/email"
@ -45,7 +44,6 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/httpclient" "github.com/superseriousbusiness/gotosocial/internal/httpclient"
"github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/media" "github.com/superseriousbusiness/gotosocial/internal/media"
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/oidc" "github.com/superseriousbusiness/gotosocial/internal/oidc"
"github.com/superseriousbusiness/gotosocial/internal/processing" "github.com/superseriousbusiness/gotosocial/internal/processing"
@ -107,19 +105,11 @@ var Start action.GTSAction = func(ctx context.Context) error {
state.Workers.Start() state.Workers.Start()
defer state.Workers.Stop() defer state.Workers.Stop()
// Create the client API and federator worker pools
// NOTE: these MUST NOT be used until they are passed to the
// processor and it is started. The reason being that the processor
// sets the Worker process functions and start the underlying pools
// TODO: move these into state.Workers (and maybe reformat worker pools).
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
// build backend handlers // build backend handlers
mediaManager := media.NewManager(&state) mediaManager := media.NewManager(&state)
oauthServer := oauth.New(ctx, dbService) oauthServer := oauth.New(ctx, dbService)
typeConverter := typeutils.NewConverter(dbService) typeConverter := typeutils.NewConverter(dbService)
federatingDB := federatingdb.New(dbService, fedWorker, typeConverter) federatingDB := federatingdb.New(&state, typeConverter)
transportController := transport.NewController(dbService, federatingDB, &federation.Clock{}, client) transportController := transport.NewController(dbService, federatingDB, &federation.Clock{}, client)
federator := federation.NewFederator(dbService, federatingDB, transportController, typeConverter, mediaManager) federator := federation.NewFederator(dbService, federatingDB, transportController, typeConverter, mediaManager)
@ -140,11 +130,15 @@ var Start action.GTSAction = func(ctx context.Context) error {
} }
// create the message processor using the other services we've created so far // create the message processor using the other services we've created so far
processor := processing.NewProcessor(typeConverter, federator, oauthServer, mediaManager, storage, dbService, emailSender, clientWorker, fedWorker) processor := processing.NewProcessor(typeConverter, federator, oauthServer, mediaManager, &state, emailSender)
if err := processor.Start(); err != nil { if err := processor.Start(); err != nil {
return fmt.Errorf("error creating processor: %s", err) return fmt.Errorf("error creating processor: %s", err)
} }
// Set state client / federator worker enqueue functions
state.Workers.EnqueueClientAPI = processor.EnqueueClientAPI
state.Workers.EnqueueFederator = processor.EnqueueFederator
/* /*
HTTP router initialization HTTP router initialization
*/ */

View file

@ -33,14 +33,13 @@ import (
"github.com/superseriousbusiness/gotosocial/cmd/gotosocial/action" "github.com/superseriousbusiness/gotosocial/cmd/gotosocial/action"
"github.com/superseriousbusiness/gotosocial/internal/api" "github.com/superseriousbusiness/gotosocial/internal/api"
apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util" apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/gotosocial" "github.com/superseriousbusiness/gotosocial/internal/gotosocial"
"github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/middleware" "github.com/superseriousbusiness/gotosocial/internal/middleware"
"github.com/superseriousbusiness/gotosocial/internal/oidc" "github.com/superseriousbusiness/gotosocial/internal/oidc"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage" "github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/web" "github.com/superseriousbusiness/gotosocial/internal/web"
"github.com/superseriousbusiness/gotosocial/testrig" "github.com/superseriousbusiness/gotosocial/testrig"
@ -48,37 +47,44 @@ import (
// Start creates and starts a gotosocial testrig server // Start creates and starts a gotosocial testrig server
var Start action.GTSAction = func(ctx context.Context) error { var Start action.GTSAction = func(ctx context.Context) error {
var state state.State
testrig.InitTestConfig() testrig.InitTestConfig()
testrig.InitTestLog() testrig.InitTestLog()
dbService := testrig.NewTestDB() // Initialize caches
testrig.StandardDBSetup(dbService, nil) state.Caches.Init()
var storageBackend *storage.Driver state.Caches.Start()
if os.Getenv("GTS_STORAGE_BACKEND") == "s3" { defer state.Caches.Stop()
storageBackend, _ = storage.NewS3Storage()
} else {
storageBackend = testrig.NewInMemoryStorage()
}
testrig.StandardStorageSetup(storageBackend, "./testrig/media")
// Create client API and federator worker pools state.DB = testrig.NewTestDB(&state)
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) testrig.StandardDBSetup(state.DB, nil)
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
if os.Getenv("GTS_STORAGE_BACKEND") == "s3" {
state.Storage, _ = storage.NewS3Storage()
} else {
state.Storage = testrig.NewInMemoryStorage()
}
testrig.StandardStorageSetup(state.Storage, "./testrig/media")
// Initialize workers.
state.Workers.Start()
defer state.Workers.Stop()
// build backend handlers // build backend handlers
transportController := testrig.NewTestTransportController(testrig.NewMockHTTPClient(func(req *http.Request) (*http.Response, error) { transportController := testrig.NewTestTransportController(&state, testrig.NewMockHTTPClient(func(req *http.Request) (*http.Response, error) {
r := io.NopCloser(bytes.NewReader([]byte{})) r := io.NopCloser(bytes.NewReader([]byte{}))
return &http.Response{ return &http.Response{
StatusCode: 200, StatusCode: 200,
Body: r, Body: r,
}, nil }, nil
}, ""), dbService, fedWorker) }, ""))
mediaManager := testrig.NewTestMediaManager(dbService, storageBackend) mediaManager := testrig.NewTestMediaManager(&state)
federator := testrig.NewTestFederator(dbService, transportController, storageBackend, mediaManager, fedWorker) federator := testrig.NewTestFederator(&state, transportController, mediaManager)
emailSender := testrig.NewEmailSender("./web/template/", nil) emailSender := testrig.NewEmailSender("./web/template/", nil)
processor := testrig.NewTestProcessor(dbService, storageBackend, federator, emailSender, mediaManager, clientWorker, fedWorker) processor := testrig.NewTestProcessor(&state, federator, emailSender, mediaManager)
if err := processor.Start(); err != nil { if err := processor.Start(); err != nil {
return fmt.Errorf("error starting processor: %s", err) return fmt.Errorf("error starting processor: %s", err)
} }
@ -87,7 +93,7 @@ var Start action.GTSAction = func(ctx context.Context) error {
HTTP router initialization HTTP router initialization
*/ */
router := testrig.NewTestRouter(dbService) router := testrig.NewTestRouter(state.DB)
// attach global middlewares which are used for every request // attach global middlewares which are used for every request
router.AttachGlobalMiddleware( router.AttachGlobalMiddleware(
@ -112,7 +118,7 @@ var Start action.GTSAction = func(ctx context.Context) error {
} }
} }
routerSession, err := dbService.GetSession(ctx) routerSession, err := state.DB.GetSession(ctx)
if err != nil { if err != nil {
return fmt.Errorf("error retrieving router session for session middleware: %w", err) return fmt.Errorf("error retrieving router session for session middleware: %w", err)
} }
@ -123,13 +129,13 @@ var Start action.GTSAction = func(ctx context.Context) error {
} }
var ( var (
authModule = api.NewAuth(dbService, processor, idp, routerSession, sessionName) // auth/oauth paths authModule = api.NewAuth(state.DB, processor, idp, routerSession, sessionName) // auth/oauth paths
clientModule = api.NewClient(dbService, processor) // api client endpoints clientModule = api.NewClient(state.DB, processor) // api client endpoints
fileserverModule = api.NewFileserver(processor) // fileserver endpoints fileserverModule = api.NewFileserver(processor) // fileserver endpoints
wellKnownModule = api.NewWellKnown(processor) // .well-known endpoints wellKnownModule = api.NewWellKnown(processor) // .well-known endpoints
nodeInfoModule = api.NewNodeInfo(processor) // nodeinfo endpoint nodeInfoModule = api.NewNodeInfo(processor) // nodeinfo endpoint
activityPubModule = api.NewActivityPub(dbService, processor) // ActivityPub endpoints activityPubModule = api.NewActivityPub(state.DB, processor) // ActivityPub endpoints
webModule = web.New(dbService, processor) // web pages + user profiles + settings panels etc webModule = web.New(state.DB, processor) // web pages + user profiles + settings panels etc
) )
// these should be routed in order // these should be routed in order
@ -142,7 +148,7 @@ var Start action.GTSAction = func(ctx context.Context) error {
activityPubModule.RoutePublicKey(router) activityPubModule.RoutePublicKey(router)
webModule.Route(router) webModule.Route(router)
gts, err := gotosocial.NewServer(dbService, router, federator, mediaManager) gts, err := gotosocial.NewServer(state.DB, router, federator, mediaManager)
if err != nil { if err != nil {
return fmt.Errorf("error creating gotosocial service: %s", err) return fmt.Errorf("error creating gotosocial service: %s", err)
} }
@ -157,8 +163,8 @@ var Start action.GTSAction = func(ctx context.Context) error {
sig := <-sigs sig := <-sigs
log.Infof(ctx, "received signal %s, shutting down", sig) log.Infof(ctx, "received signal %s, shutting down", sig)
testrig.StandardDBTeardown(dbService) testrig.StandardDBTeardown(state.DB)
testrig.StandardStorageTeardown(storageBackend) testrig.StandardStorageTeardown(state.Storage)
// close down all running services in order // close down all running services in order
if err := gts.Stop(ctx); err != nil { if err := gts.Stop(ctx); err != nil {

View file

@ -27,15 +27,14 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/api/activitypub/emoji" "github.com/superseriousbusiness/gotosocial/internal/api/activitypub/emoji"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email" "github.com/superseriousbusiness/gotosocial/internal/email"
"github.com/superseriousbusiness/gotosocial/internal/federation" "github.com/superseriousbusiness/gotosocial/internal/federation"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/media" "github.com/superseriousbusiness/gotosocial/internal/media"
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/middleware" "github.com/superseriousbusiness/gotosocial/internal/middleware"
"github.com/superseriousbusiness/gotosocial/internal/processing" "github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage" "github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/typeutils" "github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/testrig" "github.com/superseriousbusiness/gotosocial/testrig"
@ -50,6 +49,7 @@ type EmojiGetTestSuite struct {
emailSender email.Sender emailSender email.Sender
processor *processing.Processor processor *processing.Processor
storage *storage.Driver storage *storage.Driver
state state.State
testEmojis map[string]*gtsmodel.Emoji testEmojis map[string]*gtsmodel.Emoji
testAccounts map[string]*gtsmodel.Account testAccounts map[string]*gtsmodel.Account
@ -65,19 +65,23 @@ func (suite *EmojiGetTestSuite) SetupSuite() {
} }
func (suite *EmojiGetTestSuite) SetupTest() { func (suite *EmojiGetTestSuite) SetupTest() {
suite.state.Caches.Init()
testrig.StartWorkers(&suite.state)
testrig.InitTestConfig() testrig.InitTestConfig()
testrig.InitTestLog() testrig.InitTestLog()
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) suite.db = testrig.NewTestDB(&suite.state)
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) suite.state.DB = suite.db
suite.db = testrig.NewTestDB()
suite.tc = testrig.NewTestTypeConverter(suite.db)
suite.storage = testrig.NewInMemoryStorage() suite.storage = testrig.NewInMemoryStorage()
suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage) suite.state.Storage = suite.storage
suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker)
suite.tc = testrig.NewTestTypeConverter(suite.db)
suite.mediaManager = testrig.NewTestMediaManager(&suite.state)
suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager)
suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil) suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil)
suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker) suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager)
suite.emojiModule = emoji.New(suite.processor) suite.emojiModule = emoji.New(suite.processor)
testrig.StandardDBSetup(suite.db, suite.testAccounts) testrig.StandardDBSetup(suite.db, suite.testAccounts)
testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media") testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media")
@ -90,6 +94,7 @@ func (suite *EmojiGetTestSuite) SetupTest() {
func (suite *EmojiGetTestSuite) TearDownTest() { func (suite *EmojiGetTestSuite) TearDownTest() {
testrig.StandardDBTeardown(suite.db) testrig.StandardDBTeardown(suite.db)
testrig.StandardStorageTeardown(suite.storage) testrig.StandardStorageTeardown(suite.storage)
testrig.StopWorkers(&suite.state)
} }
func (suite *EmojiGetTestSuite) TestGetEmoji() { func (suite *EmojiGetTestSuite) TestGetEmoji() {

View file

@ -34,11 +34,9 @@ import (
"github.com/superseriousbusiness/activity/streams" "github.com/superseriousbusiness/activity/streams"
"github.com/superseriousbusiness/activity/streams/vocab" "github.com/superseriousbusiness/activity/streams/vocab"
"github.com/superseriousbusiness/gotosocial/internal/api/activitypub/users" "github.com/superseriousbusiness/gotosocial/internal/api/activitypub/users"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/id" "github.com/superseriousbusiness/gotosocial/internal/id"
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/testrig" "github.com/superseriousbusiness/gotosocial/testrig"
) )
@ -86,13 +84,10 @@ func (suite *InboxPostTestSuite) TestPostBlock() {
suite.NoError(err) suite.NoError(err)
body := bytes.NewReader(bodyJson) body := bytes.NewReader(bodyJson)
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) tc := testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media"))
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) federator := testrig.NewTestFederator(&suite.state, tc, suite.mediaManager)
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker)
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
emailSender := testrig.NewEmailSender("../../../../web/template/", nil) emailSender := testrig.NewEmailSender("../../../../web/template/", nil)
processor := testrig.NewTestProcessor(suite.db, suite.storage, federator, emailSender, suite.mediaManager, clientWorker, fedWorker) processor := testrig.NewTestProcessor(&suite.state, federator, emailSender, suite.mediaManager)
userModule := users.New(processor) userModule := users.New(processor)
suite.NoError(processor.Start()) suite.NoError(processor.Start())
@ -190,13 +185,10 @@ func (suite *InboxPostTestSuite) TestPostUnblock() {
suite.NoError(err) suite.NoError(err)
body := bytes.NewReader(bodyJson) body := bytes.NewReader(bodyJson)
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) tc := testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media"))
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) federator := testrig.NewTestFederator(&suite.state, tc, suite.mediaManager)
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker)
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
emailSender := testrig.NewEmailSender("../../../../web/template/", nil) emailSender := testrig.NewEmailSender("../../../../web/template/", nil)
processor := testrig.NewTestProcessor(suite.db, suite.storage, federator, emailSender, suite.mediaManager, clientWorker, fedWorker) processor := testrig.NewTestProcessor(&suite.state, federator, emailSender, suite.mediaManager)
userModule := users.New(processor) userModule := users.New(processor)
suite.NoError(processor.Start()) suite.NoError(processor.Start())
@ -291,9 +283,6 @@ func (suite *InboxPostTestSuite) TestPostUpdate() {
suite.NoError(err) suite.NoError(err)
body := bytes.NewReader(bodyJson) body := bytes.NewReader(bodyJson)
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
// use a different version of the mock http client which serves the updated // use a different version of the mock http client which serves the updated
// version of the remote account, as though it had been updated there too; // version of the remote account, as though it had been updated there too;
// this is needed so it can be dereferenced + updated properly // this is needed so it can be dereferenced + updated properly
@ -301,10 +290,11 @@ func (suite *InboxPostTestSuite) TestPostUpdate() {
mockHTTPClient.TestRemotePeople = map[string]vocab.ActivityStreamsPerson{ mockHTTPClient.TestRemotePeople = map[string]vocab.ActivityStreamsPerson{
updatedAccount.URI: asAccount, updatedAccount.URI: asAccount,
} }
tc := testrig.NewTestTransportController(mockHTTPClient, suite.db, fedWorker)
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker) tc := testrig.NewTestTransportController(&suite.state, mockHTTPClient)
federator := testrig.NewTestFederator(&suite.state, tc, suite.mediaManager)
emailSender := testrig.NewEmailSender("../../../../web/template/", nil) emailSender := testrig.NewEmailSender("../../../../web/template/", nil)
processor := testrig.NewTestProcessor(suite.db, suite.storage, federator, emailSender, suite.mediaManager, clientWorker, fedWorker) processor := testrig.NewTestProcessor(&suite.state, federator, emailSender, suite.mediaManager)
userModule := users.New(processor) userModule := users.New(processor)
suite.NoError(processor.Start()) suite.NoError(processor.Start())
@ -430,15 +420,12 @@ func (suite *InboxPostTestSuite) TestPostDelete() {
suite.NoError(err) suite.NoError(err)
body := bytes.NewReader(bodyJson) body := bytes.NewReader(bodyJson)
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) tc := testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media"))
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) federator := testrig.NewTestFederator(&suite.state, tc, suite.mediaManager)
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker)
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
emailSender := testrig.NewEmailSender("../../../../web/template/", nil) emailSender := testrig.NewEmailSender("../../../../web/template/", nil)
processor := testrig.NewTestProcessor(suite.db, suite.storage, federator, emailSender, suite.mediaManager, clientWorker, fedWorker) processor := testrig.NewTestProcessor(&suite.state, federator, emailSender, suite.mediaManager)
suite.NoError(processor.Start())
userModule := users.New(processor) userModule := users.New(processor)
suite.NoError(processor.Start())
// setup request // setup request
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()

View file

@ -32,8 +32,6 @@ import (
"github.com/superseriousbusiness/activity/streams" "github.com/superseriousbusiness/activity/streams"
"github.com/superseriousbusiness/activity/streams/vocab" "github.com/superseriousbusiness/activity/streams/vocab"
"github.com/superseriousbusiness/gotosocial/internal/api/activitypub/users" "github.com/superseriousbusiness/gotosocial/internal/api/activitypub/users"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/testrig" "github.com/superseriousbusiness/gotosocial/testrig"
) )
@ -104,13 +102,10 @@ func (suite *OutboxGetTestSuite) TestGetOutboxFirstPage() {
signedRequest := derefRequests["foss_satan_dereference_zork_outbox_first"] signedRequest := derefRequests["foss_satan_dereference_zork_outbox_first"]
targetAccount := suite.testAccounts["local_account_1"] targetAccount := suite.testAccounts["local_account_1"]
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) tc := testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media"))
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) federator := testrig.NewTestFederator(&suite.state, tc, suite.mediaManager)
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker)
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
emailSender := testrig.NewEmailSender("../../../../web/template/", nil) emailSender := testrig.NewEmailSender("../../../../web/template/", nil)
processor := testrig.NewTestProcessor(suite.db, suite.storage, federator, emailSender, suite.mediaManager, clientWorker, fedWorker) processor := testrig.NewTestProcessor(&suite.state, federator, emailSender, suite.mediaManager)
userModule := users.New(processor) userModule := users.New(processor)
suite.NoError(processor.Start()) suite.NoError(processor.Start())
@ -182,13 +177,10 @@ func (suite *OutboxGetTestSuite) TestGetOutboxNextPage() {
signedRequest := derefRequests["foss_satan_dereference_zork_outbox_next"] signedRequest := derefRequests["foss_satan_dereference_zork_outbox_next"]
targetAccount := suite.testAccounts["local_account_1"] targetAccount := suite.testAccounts["local_account_1"]
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) tc := testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media"))
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) federator := testrig.NewTestFederator(&suite.state, tc, suite.mediaManager)
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker)
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
emailSender := testrig.NewEmailSender("../../../../web/template/", nil) emailSender := testrig.NewEmailSender("../../../../web/template/", nil)
processor := testrig.NewTestProcessor(suite.db, suite.storage, federator, emailSender, suite.mediaManager, clientWorker, fedWorker) processor := testrig.NewTestProcessor(&suite.state, federator, emailSender, suite.mediaManager)
userModule := users.New(processor) userModule := users.New(processor)
suite.NoError(processor.Start()) suite.NoError(processor.Start())

View file

@ -33,8 +33,6 @@ import (
"github.com/superseriousbusiness/activity/streams" "github.com/superseriousbusiness/activity/streams"
"github.com/superseriousbusiness/activity/streams/vocab" "github.com/superseriousbusiness/activity/streams/vocab"
"github.com/superseriousbusiness/gotosocial/internal/api/activitypub/users" "github.com/superseriousbusiness/gotosocial/internal/api/activitypub/users"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/testrig" "github.com/superseriousbusiness/gotosocial/testrig"
) )
@ -104,13 +102,10 @@ func (suite *RepliesGetTestSuite) TestGetRepliesNext() {
targetAccount := suite.testAccounts["local_account_1"] targetAccount := suite.testAccounts["local_account_1"]
targetStatus := suite.testStatuses["local_account_1_status_1"] targetStatus := suite.testStatuses["local_account_1_status_1"]
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) tc := testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media"))
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) federator := testrig.NewTestFederator(&suite.state, tc, suite.mediaManager)
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker)
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
emailSender := testrig.NewEmailSender("../../../../web/template/", nil) emailSender := testrig.NewEmailSender("../../../../web/template/", nil)
processor := testrig.NewTestProcessor(suite.db, suite.storage, federator, emailSender, suite.mediaManager, clientWorker, fedWorker) processor := testrig.NewTestProcessor(&suite.state, federator, emailSender, suite.mediaManager)
userModule := users.New(processor) userModule := users.New(processor)
suite.NoError(processor.Start()) suite.NoError(processor.Start())
@ -172,13 +167,10 @@ func (suite *RepliesGetTestSuite) TestGetRepliesLast() {
targetAccount := suite.testAccounts["local_account_1"] targetAccount := suite.testAccounts["local_account_1"]
targetStatus := suite.testStatuses["local_account_1_status_1"] targetStatus := suite.testStatuses["local_account_1_status_1"]
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) tc := testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media"))
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) federator := testrig.NewTestFederator(&suite.state, tc, suite.mediaManager)
tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker)
federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker)
emailSender := testrig.NewEmailSender("../../../../web/template/", nil) emailSender := testrig.NewEmailSender("../../../../web/template/", nil)
processor := testrig.NewTestProcessor(suite.db, suite.storage, federator, emailSender, suite.mediaManager, clientWorker, fedWorker) processor := testrig.NewTestProcessor(&suite.state, federator, emailSender, suite.mediaManager)
userModule := users.New(processor) userModule := users.New(processor)
suite.NoError(processor.Start()) suite.NoError(processor.Start())

View file

@ -22,15 +22,14 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/api/activitypub/users" "github.com/superseriousbusiness/gotosocial/internal/api/activitypub/users"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email" "github.com/superseriousbusiness/gotosocial/internal/email"
"github.com/superseriousbusiness/gotosocial/internal/federation" "github.com/superseriousbusiness/gotosocial/internal/federation"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/media" "github.com/superseriousbusiness/gotosocial/internal/media"
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/middleware" "github.com/superseriousbusiness/gotosocial/internal/middleware"
"github.com/superseriousbusiness/gotosocial/internal/processing" "github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage" "github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/typeutils" "github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/testrig" "github.com/superseriousbusiness/gotosocial/testrig"
@ -46,6 +45,7 @@ type UserStandardTestSuite struct {
emailSender email.Sender emailSender email.Sender
processor *processing.Processor processor *processing.Processor
storage *storage.Driver storage *storage.Driver
state state.State
// standard suite models // standard suite models
testTokens map[string]*gtsmodel.Token testTokens map[string]*gtsmodel.Token
@ -75,19 +75,21 @@ func (suite *UserStandardTestSuite) SetupSuite() {
} }
func (suite *UserStandardTestSuite) SetupTest() { func (suite *UserStandardTestSuite) SetupTest() {
suite.state.Caches.Init()
testrig.StartWorkers(&suite.state)
testrig.InitTestConfig() testrig.InitTestConfig()
testrig.InitTestLog() testrig.InitTestLog()
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) suite.db = testrig.NewTestDB(&suite.state)
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) suite.state.DB = suite.db
suite.db = testrig.NewTestDB()
suite.tc = testrig.NewTestTypeConverter(suite.db) suite.tc = testrig.NewTestTypeConverter(suite.db)
suite.storage = testrig.NewInMemoryStorage() suite.storage = testrig.NewInMemoryStorage()
suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage) suite.state.Storage = suite.storage
suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker) suite.mediaManager = testrig.NewTestMediaManager(&suite.state)
suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager)
suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil) suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil)
suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker) suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager)
suite.userModule = users.New(suite.processor) suite.userModule = users.New(suite.processor)
testrig.StandardDBSetup(suite.db, suite.testAccounts) testrig.StandardDBSetup(suite.db, suite.testAccounts)
testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media") testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media")
@ -100,4 +102,5 @@ func (suite *UserStandardTestSuite) SetupTest() {
func (suite *UserStandardTestSuite) TearDownTest() { func (suite *UserStandardTestSuite) TearDownTest() {
testrig.StandardDBTeardown(suite.db) testrig.StandardDBTeardown(suite.db)
testrig.StandardStorageTeardown(suite.storage) testrig.StandardStorageTeardown(suite.storage)
testrig.StopWorkers(&suite.state)
} }

View file

@ -28,17 +28,16 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/api/auth" "github.com/superseriousbusiness/gotosocial/internal/api/auth"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email" "github.com/superseriousbusiness/gotosocial/internal/email"
"github.com/superseriousbusiness/gotosocial/internal/federation" "github.com/superseriousbusiness/gotosocial/internal/federation"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/media" "github.com/superseriousbusiness/gotosocial/internal/media"
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/middleware" "github.com/superseriousbusiness/gotosocial/internal/middleware"
"github.com/superseriousbusiness/gotosocial/internal/oidc" "github.com/superseriousbusiness/gotosocial/internal/oidc"
"github.com/superseriousbusiness/gotosocial/internal/processing" "github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage" "github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/testrig" "github.com/superseriousbusiness/gotosocial/testrig"
) )
@ -47,6 +46,7 @@ type AuthStandardTestSuite struct {
suite.Suite suite.Suite
db db.DB db db.DB
storage *storage.Driver storage *storage.Driver
state state.State
mediaManager media.Manager mediaManager media.Manager
federator federation.Federator federator federation.Federator
processor *processing.Processor processor *processing.Processor
@ -78,18 +78,19 @@ func (suite *AuthStandardTestSuite) SetupSuite() {
} }
func (suite *AuthStandardTestSuite) SetupTest() { func (suite *AuthStandardTestSuite) SetupTest() {
suite.state.Caches.Init()
testrig.InitTestConfig() testrig.InitTestConfig()
testrig.InitTestLog() testrig.InitTestLog()
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) suite.db = testrig.NewTestDB(&suite.state)
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) suite.state.DB = suite.db
suite.db = testrig.NewTestDB()
suite.storage = testrig.NewInMemoryStorage() suite.storage = testrig.NewInMemoryStorage()
suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage) suite.state.Storage = suite.storage
suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker) suite.mediaManager = testrig.NewTestMediaManager(&suite.state)
suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../testrig/media")), suite.mediaManager)
suite.emailSender = testrig.NewEmailSender("../../../web/template/", nil) suite.emailSender = testrig.NewEmailSender("../../../web/template/", nil)
suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker) suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager)
suite.authModule = auth.New(suite.db, suite.processor, suite.idp) suite.authModule = auth.New(suite.db, suite.processor, suite.idp)
testrig.StandardDBSetup(suite.db, suite.testAccounts) testrig.StandardDBSetup(suite.db, suite.testAccounts)
} }

View file

@ -27,16 +27,15 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/api/client/accounts" "github.com/superseriousbusiness/gotosocial/internal/api/client/accounts"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email" "github.com/superseriousbusiness/gotosocial/internal/email"
"github.com/superseriousbusiness/gotosocial/internal/federation" "github.com/superseriousbusiness/gotosocial/internal/federation"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/media" "github.com/superseriousbusiness/gotosocial/internal/media"
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/processing" "github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage" "github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/testrig" "github.com/superseriousbusiness/gotosocial/testrig"
) )
@ -51,6 +50,7 @@ type AccountStandardTestSuite struct {
processor *processing.Processor processor *processing.Processor
emailSender email.Sender emailSender email.Sender
sentEmails map[string]string sentEmails map[string]string
state state.State
// standard suite models // standard suite models
testTokens map[string]*gtsmodel.Token testTokens map[string]*gtsmodel.Token
@ -76,19 +76,22 @@ func (suite *AccountStandardTestSuite) SetupSuite() {
} }
func (suite *AccountStandardTestSuite) SetupTest() { func (suite *AccountStandardTestSuite) SetupTest() {
suite.state.Caches.Init()
testrig.StartWorkers(&suite.state)
testrig.InitTestConfig() testrig.InitTestConfig()
testrig.InitTestLog() testrig.InitTestLog()
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) suite.db = testrig.NewTestDB(&suite.state)
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) suite.state.DB = suite.db
suite.db = testrig.NewTestDB()
suite.storage = testrig.NewInMemoryStorage() suite.storage = testrig.NewInMemoryStorage()
suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage) suite.state.Storage = suite.storage
suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker)
suite.mediaManager = testrig.NewTestMediaManager(&suite.state)
suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager)
suite.sentEmails = make(map[string]string) suite.sentEmails = make(map[string]string)
suite.emailSender = testrig.NewEmailSender("../../../../web/template/", suite.sentEmails) suite.emailSender = testrig.NewEmailSender("../../../../web/template/", suite.sentEmails)
suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker) suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager)
suite.accountsModule = accounts.New(suite.processor) suite.accountsModule = accounts.New(suite.processor)
testrig.StandardDBSetup(suite.db, nil) testrig.StandardDBSetup(suite.db, nil)
testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media") testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media")
@ -99,6 +102,7 @@ func (suite *AccountStandardTestSuite) SetupTest() {
func (suite *AccountStandardTestSuite) TearDownTest() { func (suite *AccountStandardTestSuite) TearDownTest() {
testrig.StandardDBTeardown(suite.db) testrig.StandardDBTeardown(suite.db)
testrig.StandardStorageTeardown(suite.storage) testrig.StandardStorageTeardown(suite.storage)
testrig.StopWorkers(&suite.state)
} }
func (suite *AccountStandardTestSuite) newContext(recorder *httptest.ResponseRecorder, requestMethod string, requestBody []byte, requestPath string, bodyContentType string) *gin.Context { func (suite *AccountStandardTestSuite) newContext(recorder *httptest.ResponseRecorder, requestMethod string, requestBody []byte, requestPath string, bodyContentType string) *gin.Context {

View file

@ -27,16 +27,15 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/api/client/admin" "github.com/superseriousbusiness/gotosocial/internal/api/client/admin"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email" "github.com/superseriousbusiness/gotosocial/internal/email"
"github.com/superseriousbusiness/gotosocial/internal/federation" "github.com/superseriousbusiness/gotosocial/internal/federation"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/media" "github.com/superseriousbusiness/gotosocial/internal/media"
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/processing" "github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage" "github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/testrig" "github.com/superseriousbusiness/gotosocial/testrig"
) )
@ -51,6 +50,7 @@ type AdminStandardTestSuite struct {
processor *processing.Processor processor *processing.Processor
emailSender email.Sender emailSender email.Sender
sentEmails map[string]string sentEmails map[string]string
state state.State
// standard suite models // standard suite models
testTokens map[string]*gtsmodel.Token testTokens map[string]*gtsmodel.Token
@ -82,19 +82,22 @@ func (suite *AdminStandardTestSuite) SetupSuite() {
} }
func (suite *AdminStandardTestSuite) SetupTest() { func (suite *AdminStandardTestSuite) SetupTest() {
suite.state.Caches.Init()
testrig.StartWorkers(&suite.state)
testrig.InitTestConfig() testrig.InitTestConfig()
testrig.InitTestLog() testrig.InitTestLog()
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) suite.db = testrig.NewTestDB(&suite.state)
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) suite.state.DB = suite.db
suite.db = testrig.NewTestDB()
suite.storage = testrig.NewInMemoryStorage() suite.storage = testrig.NewInMemoryStorage()
suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage) suite.state.Storage = suite.storage
suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker)
suite.mediaManager = testrig.NewTestMediaManager(&suite.state)
suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager)
suite.sentEmails = make(map[string]string) suite.sentEmails = make(map[string]string)
suite.emailSender = testrig.NewEmailSender("../../../../web/template/", suite.sentEmails) suite.emailSender = testrig.NewEmailSender("../../../../web/template/", suite.sentEmails)
suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker) suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager)
suite.adminModule = admin.New(suite.processor) suite.adminModule = admin.New(suite.processor)
testrig.StandardDBSetup(suite.db, nil) testrig.StandardDBSetup(suite.db, nil)
testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media") testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media")
@ -103,6 +106,7 @@ func (suite *AdminStandardTestSuite) SetupTest() {
func (suite *AdminStandardTestSuite) TearDownTest() { func (suite *AdminStandardTestSuite) TearDownTest() {
testrig.StandardDBTeardown(suite.db) testrig.StandardDBTeardown(suite.db)
testrig.StandardStorageTeardown(suite.storage) testrig.StandardStorageTeardown(suite.storage)
testrig.StopWorkers(&suite.state)
} }
func (suite *AdminStandardTestSuite) newContext(recorder *httptest.ResponseRecorder, requestMethod string, requestBody []byte, requestPath string, bodyContentType string) *gin.Context { func (suite *AdminStandardTestSuite) newContext(recorder *httptest.ResponseRecorder, requestMethod string, requestBody []byte, requestPath string, bodyContentType string) *gin.Context {

View file

@ -32,16 +32,15 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/api/client/bookmarks" "github.com/superseriousbusiness/gotosocial/internal/api/client/bookmarks"
"github.com/superseriousbusiness/gotosocial/internal/api/client/statuses" "github.com/superseriousbusiness/gotosocial/internal/api/client/statuses"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email" "github.com/superseriousbusiness/gotosocial/internal/email"
"github.com/superseriousbusiness/gotosocial/internal/federation" "github.com/superseriousbusiness/gotosocial/internal/federation"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/media" "github.com/superseriousbusiness/gotosocial/internal/media"
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/processing" "github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage" "github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/typeutils" "github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/testrig" "github.com/superseriousbusiness/gotosocial/testrig"
@ -57,6 +56,7 @@ type BookmarkTestSuite struct {
emailSender email.Sender emailSender email.Sender
processor *processing.Processor processor *processing.Processor
storage *storage.Driver storage *storage.Driver
state state.State
// standard suite models // standard suite models
testTokens map[string]*gtsmodel.Token testTokens map[string]*gtsmodel.Token
@ -87,22 +87,25 @@ func (suite *BookmarkTestSuite) SetupSuite() {
} }
func (suite *BookmarkTestSuite) SetupTest() { func (suite *BookmarkTestSuite) SetupTest() {
suite.state.Caches.Init()
testrig.StartWorkers(&suite.state)
testrig.InitTestConfig() testrig.InitTestConfig()
testrig.InitTestLog() testrig.InitTestLog()
suite.db = testrig.NewTestDB() suite.db = testrig.NewTestDB(&suite.state)
suite.tc = testrig.NewTestTypeConverter(suite.db) suite.state.DB = suite.db
suite.storage = testrig.NewInMemoryStorage() suite.storage = testrig.NewInMemoryStorage()
suite.state.Storage = suite.storage
suite.tc = testrig.NewTestTypeConverter(suite.db)
testrig.StandardDBSetup(suite.db, nil) testrig.StandardDBSetup(suite.db, nil)
testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media") testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media")
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) suite.mediaManager = testrig.NewTestMediaManager(&suite.state)
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager)
suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage)
suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker)
suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil) suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil)
suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker) suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager)
suite.statusModule = statuses.New(suite.processor) suite.statusModule = statuses.New(suite.processor)
suite.bookmarkModule = bookmarks.New(suite.processor) suite.bookmarkModule = bookmarks.New(suite.processor)
@ -112,6 +115,7 @@ func (suite *BookmarkTestSuite) SetupTest() {
func (suite *BookmarkTestSuite) TearDownTest() { func (suite *BookmarkTestSuite) TearDownTest() {
testrig.StandardDBTeardown(suite.db) testrig.StandardDBTeardown(suite.db)
testrig.StandardStorageTeardown(suite.storage) testrig.StandardStorageTeardown(suite.storage)
testrig.StopWorkers(&suite.state)
} }
func (suite *BookmarkTestSuite) getBookmarks( func (suite *BookmarkTestSuite) getBookmarks(

View file

@ -21,14 +21,13 @@ package favourites_test
import ( import (
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/api/client/favourites" "github.com/superseriousbusiness/gotosocial/internal/api/client/favourites"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email" "github.com/superseriousbusiness/gotosocial/internal/email"
"github.com/superseriousbusiness/gotosocial/internal/federation" "github.com/superseriousbusiness/gotosocial/internal/federation"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/media" "github.com/superseriousbusiness/gotosocial/internal/media"
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/processing" "github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage" "github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/typeutils" "github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/testrig" "github.com/superseriousbusiness/gotosocial/testrig"
@ -44,6 +43,7 @@ type FavouritesStandardTestSuite struct {
emailSender email.Sender emailSender email.Sender
processor *processing.Processor processor *processing.Processor
storage *storage.Driver storage *storage.Driver
state state.State
// standard suite models // standard suite models
testTokens map[string]*gtsmodel.Token testTokens map[string]*gtsmodel.Token
@ -71,22 +71,25 @@ func (suite *FavouritesStandardTestSuite) SetupSuite() {
} }
func (suite *FavouritesStandardTestSuite) SetupTest() { func (suite *FavouritesStandardTestSuite) SetupTest() {
suite.state.Caches.Init()
testrig.StartWorkers(&suite.state)
testrig.InitTestConfig() testrig.InitTestConfig()
testrig.InitTestLog() testrig.InitTestLog()
suite.db = testrig.NewTestDB() suite.db = testrig.NewTestDB(&suite.state)
suite.tc = testrig.NewTestTypeConverter(suite.db) suite.state.DB = suite.db
suite.storage = testrig.NewInMemoryStorage() suite.storage = testrig.NewInMemoryStorage()
suite.state.Storage = suite.storage
suite.tc = testrig.NewTestTypeConverter(suite.db)
testrig.StandardDBSetup(suite.db, nil) testrig.StandardDBSetup(suite.db, nil)
testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media") testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media")
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) suite.mediaManager = testrig.NewTestMediaManager(&suite.state)
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager)
suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage)
suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker)
suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil) suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil)
suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker) suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager)
suite.favModule = favourites.New(suite.processor) suite.favModule = favourites.New(suite.processor)
suite.NoError(suite.processor.Start()) suite.NoError(suite.processor.Start())
@ -95,6 +98,7 @@ func (suite *FavouritesStandardTestSuite) SetupTest() {
func (suite *FavouritesStandardTestSuite) TearDownTest() { func (suite *FavouritesStandardTestSuite) TearDownTest() {
testrig.StandardDBTeardown(suite.db) testrig.StandardDBTeardown(suite.db)
testrig.StandardStorageTeardown(suite.storage) testrig.StandardStorageTeardown(suite.storage)
testrig.StopWorkers(&suite.state)
} }
func (suite *FavouritesStandardTestSuite) TestProcessFave() {} func (suite *FavouritesStandardTestSuite) TestProcessFave() {}

View file

@ -26,16 +26,15 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/api/client/followrequests" "github.com/superseriousbusiness/gotosocial/internal/api/client/followrequests"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email" "github.com/superseriousbusiness/gotosocial/internal/email"
"github.com/superseriousbusiness/gotosocial/internal/federation" "github.com/superseriousbusiness/gotosocial/internal/federation"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/media" "github.com/superseriousbusiness/gotosocial/internal/media"
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/processing" "github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage" "github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/testrig" "github.com/superseriousbusiness/gotosocial/testrig"
) )
@ -48,6 +47,7 @@ type FollowRequestStandardTestSuite struct {
federator federation.Federator federator federation.Federator
processor *processing.Processor processor *processing.Processor
emailSender email.Sender emailSender email.Sender
state state.State
// standard suite models // standard suite models
testTokens map[string]*gtsmodel.Token testTokens map[string]*gtsmodel.Token
@ -73,18 +73,21 @@ func (suite *FollowRequestStandardTestSuite) SetupSuite() {
} }
func (suite *FollowRequestStandardTestSuite) SetupTest() { func (suite *FollowRequestStandardTestSuite) SetupTest() {
suite.state.Caches.Init()
testrig.StartWorkers(&suite.state)
testrig.InitTestConfig() testrig.InitTestConfig()
testrig.InitTestLog() testrig.InitTestLog()
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) suite.db = testrig.NewTestDB(&suite.state)
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) suite.state.DB = suite.db
suite.db = testrig.NewTestDB()
suite.storage = testrig.NewInMemoryStorage() suite.storage = testrig.NewInMemoryStorage()
suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage) suite.state.Storage = suite.storage
suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker)
suite.mediaManager = testrig.NewTestMediaManager(&suite.state)
suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager)
suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil) suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil)
suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker) suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager)
suite.followRequestModule = followrequests.New(suite.processor) suite.followRequestModule = followrequests.New(suite.processor)
testrig.StandardDBSetup(suite.db, nil) testrig.StandardDBSetup(suite.db, nil)
testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media") testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media")
@ -95,6 +98,7 @@ func (suite *FollowRequestStandardTestSuite) SetupTest() {
func (suite *FollowRequestStandardTestSuite) TearDownTest() { func (suite *FollowRequestStandardTestSuite) TearDownTest() {
testrig.StandardDBTeardown(suite.db) testrig.StandardDBTeardown(suite.db)
testrig.StandardStorageTeardown(suite.storage) testrig.StandardStorageTeardown(suite.storage)
testrig.StopWorkers(&suite.state)
} }
func (suite *FollowRequestStandardTestSuite) newContext(recorder *httptest.ResponseRecorder, requestMethod string, requestBody []byte, requestPath string, bodyContentType string) *gin.Context { func (suite *FollowRequestStandardTestSuite) newContext(recorder *httptest.ResponseRecorder, requestMethod string, requestBody []byte, requestPath string, bodyContentType string) *gin.Context {

View file

@ -26,16 +26,15 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/api/client/instance" "github.com/superseriousbusiness/gotosocial/internal/api/client/instance"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email" "github.com/superseriousbusiness/gotosocial/internal/email"
"github.com/superseriousbusiness/gotosocial/internal/federation" "github.com/superseriousbusiness/gotosocial/internal/federation"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/media" "github.com/superseriousbusiness/gotosocial/internal/media"
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/processing" "github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage" "github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/testrig" "github.com/superseriousbusiness/gotosocial/testrig"
) )
@ -50,6 +49,7 @@ type InstanceStandardTestSuite struct {
processor *processing.Processor processor *processing.Processor
emailSender email.Sender emailSender email.Sender
sentEmails map[string]string sentEmails map[string]string
state state.State
// standard suite models // standard suite models
testTokens map[string]*gtsmodel.Token testTokens map[string]*gtsmodel.Token
@ -75,19 +75,22 @@ func (suite *InstanceStandardTestSuite) SetupSuite() {
} }
func (suite *InstanceStandardTestSuite) SetupTest() { func (suite *InstanceStandardTestSuite) SetupTest() {
suite.state.Caches.Init()
testrig.StartWorkers(&suite.state)
testrig.InitTestConfig() testrig.InitTestConfig()
testrig.InitTestLog() testrig.InitTestLog()
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) suite.db = testrig.NewTestDB(&suite.state)
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) suite.state.DB = suite.db
suite.db = testrig.NewTestDB()
suite.storage = testrig.NewInMemoryStorage() suite.storage = testrig.NewInMemoryStorage()
suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage) suite.state.Storage = suite.storage
suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker)
suite.mediaManager = testrig.NewTestMediaManager(&suite.state)
suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager)
suite.sentEmails = make(map[string]string) suite.sentEmails = make(map[string]string)
suite.emailSender = testrig.NewEmailSender("../../../../web/template/", suite.sentEmails) suite.emailSender = testrig.NewEmailSender("../../../../web/template/", suite.sentEmails)
suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker) suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager)
suite.instanceModule = instance.New(suite.processor) suite.instanceModule = instance.New(suite.processor)
testrig.StandardDBSetup(suite.db, nil) testrig.StandardDBSetup(suite.db, nil)
testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media") testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media")
@ -96,6 +99,7 @@ func (suite *InstanceStandardTestSuite) SetupTest() {
func (suite *InstanceStandardTestSuite) TearDownTest() { func (suite *InstanceStandardTestSuite) TearDownTest() {
testrig.StandardDBTeardown(suite.db) testrig.StandardDBTeardown(suite.db)
testrig.StandardStorageTeardown(suite.storage) testrig.StandardStorageTeardown(suite.storage)
testrig.StopWorkers(&suite.state)
} }
func (suite *InstanceStandardTestSuite) newContext(recorder *httptest.ResponseRecorder, method string, path string, body []byte, contentType string, auth bool) *gin.Context { func (suite *InstanceStandardTestSuite) newContext(recorder *httptest.ResponseRecorder, method string, path string, body []byte, contentType string, auth bool) *gin.Context {

View file

@ -33,7 +33,6 @@ import (
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
mediamodule "github.com/superseriousbusiness/gotosocial/internal/api/client/media" mediamodule "github.com/superseriousbusiness/gotosocial/internal/api/client/media"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email" "github.com/superseriousbusiness/gotosocial/internal/email"
@ -41,9 +40,9 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/media" "github.com/superseriousbusiness/gotosocial/internal/media"
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/processing" "github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage" "github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/typeutils" "github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/testrig" "github.com/superseriousbusiness/gotosocial/testrig"
@ -60,6 +59,7 @@ type MediaCreateTestSuite struct {
oauthServer oauth.Server oauthServer oauth.Server
emailSender email.Sender emailSender email.Sender
processor *processing.Processor processor *processing.Processor
state state.State
// standard suite models // standard suite models
testTokens map[string]*gtsmodel.Token testTokens map[string]*gtsmodel.Token
@ -78,21 +78,24 @@ type MediaCreateTestSuite struct {
*/ */
func (suite *MediaCreateTestSuite) SetupSuite() { func (suite *MediaCreateTestSuite) SetupSuite() {
suite.state.Caches.Init()
testrig.StartWorkers(&suite.state)
// setup standard items // setup standard items
testrig.InitTestConfig() testrig.InitTestConfig()
testrig.InitTestLog() testrig.InitTestLog()
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) suite.db = testrig.NewTestDB(&suite.state)
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) suite.state.DB = suite.db
suite.db = testrig.NewTestDB()
suite.storage = testrig.NewInMemoryStorage() suite.storage = testrig.NewInMemoryStorage()
suite.state.Storage = suite.storage
suite.tc = testrig.NewTestTypeConverter(suite.db) suite.tc = testrig.NewTestTypeConverter(suite.db)
suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage) suite.mediaManager = testrig.NewTestMediaManager(&suite.state)
suite.oauthServer = testrig.NewTestOauthServer(suite.db) suite.oauthServer = testrig.NewTestOauthServer(suite.db)
suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker) suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager)
suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil) suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil)
suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker) suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager)
// setup module being tested // setup module being tested
suite.mediaModule = mediamodule.New(suite.processor) suite.mediaModule = mediamodule.New(suite.processor)
@ -102,11 +105,15 @@ func (suite *MediaCreateTestSuite) TearDownSuite() {
if err := suite.db.Stop(context.Background()); err != nil { if err := suite.db.Stop(context.Background()); err != nil {
log.Panicf(nil, "error closing db connection: %s", err) log.Panicf(nil, "error closing db connection: %s", err)
} }
testrig.StopWorkers(&suite.state)
} }
func (suite *MediaCreateTestSuite) SetupTest() { func (suite *MediaCreateTestSuite) SetupTest() {
suite.state.Caches.Init()
testrig.StandardDBSetup(suite.db, nil) testrig.StandardDBSetup(suite.db, nil)
testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media") testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media")
suite.testTokens = testrig.NewTestTokens() suite.testTokens = testrig.NewTestTokens()
suite.testClients = testrig.NewTestClients() suite.testClients = testrig.NewTestClients()
suite.testApplications = testrig.NewTestApplications() suite.testApplications = testrig.NewTestApplications()

View file

@ -31,7 +31,6 @@ import (
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
mediamodule "github.com/superseriousbusiness/gotosocial/internal/api/client/media" mediamodule "github.com/superseriousbusiness/gotosocial/internal/api/client/media"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email" "github.com/superseriousbusiness/gotosocial/internal/email"
@ -39,9 +38,9 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/media" "github.com/superseriousbusiness/gotosocial/internal/media"
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/processing" "github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage" "github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/typeutils" "github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/testrig" "github.com/superseriousbusiness/gotosocial/testrig"
@ -58,6 +57,7 @@ type MediaUpdateTestSuite struct {
oauthServer oauth.Server oauthServer oauth.Server
emailSender email.Sender emailSender email.Sender
processor *processing.Processor processor *processing.Processor
state state.State
// standard suite models // standard suite models
testTokens map[string]*gtsmodel.Token testTokens map[string]*gtsmodel.Token
@ -76,21 +76,23 @@ type MediaUpdateTestSuite struct {
*/ */
func (suite *MediaUpdateTestSuite) SetupSuite() { func (suite *MediaUpdateTestSuite) SetupSuite() {
testrig.StartWorkers(&suite.state)
// setup standard items // setup standard items
testrig.InitTestConfig() testrig.InitTestConfig()
testrig.InitTestLog() testrig.InitTestLog()
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) suite.db = testrig.NewTestDB(&suite.state)
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) suite.state.DB = suite.db
suite.db = testrig.NewTestDB()
suite.storage = testrig.NewInMemoryStorage() suite.storage = testrig.NewInMemoryStorage()
suite.state.Storage = suite.storage
suite.tc = testrig.NewTestTypeConverter(suite.db) suite.tc = testrig.NewTestTypeConverter(suite.db)
suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage) suite.mediaManager = testrig.NewTestMediaManager(&suite.state)
suite.oauthServer = testrig.NewTestOauthServer(suite.db) suite.oauthServer = testrig.NewTestOauthServer(suite.db)
suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker) suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager)
suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil) suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil)
suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker) suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager)
// setup module being tested // setup module being tested
suite.mediaModule = mediamodule.New(suite.processor) suite.mediaModule = mediamodule.New(suite.processor)
@ -100,11 +102,15 @@ func (suite *MediaUpdateTestSuite) TearDownSuite() {
if err := suite.db.Stop(context.Background()); err != nil { if err := suite.db.Stop(context.Background()); err != nil {
log.Panicf(nil, "error closing db connection: %s", err) log.Panicf(nil, "error closing db connection: %s", err)
} }
testrig.StopWorkers(&suite.state)
} }
func (suite *MediaUpdateTestSuite) SetupTest() { func (suite *MediaUpdateTestSuite) SetupTest() {
suite.state.Caches.Init()
testrig.StandardDBSetup(suite.db, nil) testrig.StandardDBSetup(suite.db, nil)
testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media") testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media")
suite.testTokens = testrig.NewTestTokens() suite.testTokens = testrig.NewTestTokens()
suite.testClients = testrig.NewTestClients() suite.testClients = testrig.NewTestClients()
suite.testApplications = testrig.NewTestApplications() suite.testApplications = testrig.NewTestApplications()

View file

@ -21,14 +21,13 @@ package reports_test
import ( import (
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/api/client/reports" "github.com/superseriousbusiness/gotosocial/internal/api/client/reports"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email" "github.com/superseriousbusiness/gotosocial/internal/email"
"github.com/superseriousbusiness/gotosocial/internal/federation" "github.com/superseriousbusiness/gotosocial/internal/federation"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/media" "github.com/superseriousbusiness/gotosocial/internal/media"
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/processing" "github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage" "github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/testrig" "github.com/superseriousbusiness/gotosocial/testrig"
) )
@ -42,6 +41,7 @@ type ReportsStandardTestSuite struct {
processor *processing.Processor processor *processing.Processor
emailSender email.Sender emailSender email.Sender
sentEmails map[string]string sentEmails map[string]string
state state.State
// standard suite models // standard suite models
testTokens map[string]*gtsmodel.Token testTokens map[string]*gtsmodel.Token
@ -67,19 +67,22 @@ func (suite *ReportsStandardTestSuite) SetupSuite() {
} }
func (suite *ReportsStandardTestSuite) SetupTest() { func (suite *ReportsStandardTestSuite) SetupTest() {
suite.state.Caches.Init()
testrig.StartWorkers(&suite.state)
testrig.InitTestConfig() testrig.InitTestConfig()
testrig.InitTestLog() testrig.InitTestLog()
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) suite.db = testrig.NewTestDB(&suite.state)
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) suite.state.DB = suite.db
suite.db = testrig.NewTestDB()
suite.storage = testrig.NewInMemoryStorage() suite.storage = testrig.NewInMemoryStorage()
suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage) suite.state.Storage = suite.storage
suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker)
suite.mediaManager = testrig.NewTestMediaManager(&suite.state)
suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager)
suite.sentEmails = make(map[string]string) suite.sentEmails = make(map[string]string)
suite.emailSender = testrig.NewEmailSender("../../../../web/template/", suite.sentEmails) suite.emailSender = testrig.NewEmailSender("../../../../web/template/", suite.sentEmails)
suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker) suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager)
suite.reportsModule = reports.New(suite.processor) suite.reportsModule = reports.New(suite.processor)
testrig.StandardDBSetup(suite.db, nil) testrig.StandardDBSetup(suite.db, nil)
testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media") testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media")
@ -90,4 +93,5 @@ func (suite *ReportsStandardTestSuite) SetupTest() {
func (suite *ReportsStandardTestSuite) TearDownTest() { func (suite *ReportsStandardTestSuite) TearDownTest() {
testrig.StandardDBTeardown(suite.db) testrig.StandardDBTeardown(suite.db)
testrig.StandardStorageTeardown(suite.storage) testrig.StandardStorageTeardown(suite.storage)
testrig.StopWorkers(&suite.state)
} }

View file

@ -26,16 +26,15 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/api/client/search" "github.com/superseriousbusiness/gotosocial/internal/api/client/search"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email" "github.com/superseriousbusiness/gotosocial/internal/email"
"github.com/superseriousbusiness/gotosocial/internal/federation" "github.com/superseriousbusiness/gotosocial/internal/federation"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/media" "github.com/superseriousbusiness/gotosocial/internal/media"
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/processing" "github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage" "github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/testrig" "github.com/superseriousbusiness/gotosocial/testrig"
) )
@ -50,6 +49,7 @@ type SearchStandardTestSuite struct {
processor *processing.Processor processor *processing.Processor
emailSender email.Sender emailSender email.Sender
sentEmails map[string]string sentEmails map[string]string
state state.State
// standard suite models // standard suite models
testTokens map[string]*gtsmodel.Token testTokens map[string]*gtsmodel.Token
@ -71,19 +71,22 @@ func (suite *SearchStandardTestSuite) SetupSuite() {
} }
func (suite *SearchStandardTestSuite) SetupTest() { func (suite *SearchStandardTestSuite) SetupTest() {
suite.state.Caches.Init()
testrig.StartWorkers(&suite.state)
testrig.InitTestConfig() testrig.InitTestConfig()
testrig.InitTestLog() testrig.InitTestLog()
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) suite.db = testrig.NewTestDB(&suite.state)
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) suite.state.DB = suite.db
suite.db = testrig.NewTestDB()
suite.storage = testrig.NewInMemoryStorage() suite.storage = testrig.NewInMemoryStorage()
suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage) suite.state.Storage = suite.storage
suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker)
suite.mediaManager = testrig.NewTestMediaManager(&suite.state)
suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager)
suite.sentEmails = make(map[string]string) suite.sentEmails = make(map[string]string)
suite.emailSender = testrig.NewEmailSender("../../../../web/template/", suite.sentEmails) suite.emailSender = testrig.NewEmailSender("../../../../web/template/", suite.sentEmails)
suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker) suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager)
suite.searchModule = search.New(suite.processor) suite.searchModule = search.New(suite.processor)
testrig.StandardDBSetup(suite.db, nil) testrig.StandardDBSetup(suite.db, nil)
testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media") testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media")
@ -94,6 +97,7 @@ func (suite *SearchStandardTestSuite) SetupTest() {
func (suite *SearchStandardTestSuite) TearDownTest() { func (suite *SearchStandardTestSuite) TearDownTest() {
testrig.StandardDBTeardown(suite.db) testrig.StandardDBTeardown(suite.db)
testrig.StandardStorageTeardown(suite.storage) testrig.StandardStorageTeardown(suite.storage)
testrig.StopWorkers(&suite.state)
} }
func (suite *SearchStandardTestSuite) newContext(recorder *httptest.ResponseRecorder, requestPath string) *gin.Context { func (suite *SearchStandardTestSuite) newContext(recorder *httptest.ResponseRecorder, requestPath string) *gin.Context {

View file

@ -21,14 +21,13 @@ package statuses_test
import ( import (
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/api/client/statuses" "github.com/superseriousbusiness/gotosocial/internal/api/client/statuses"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email" "github.com/superseriousbusiness/gotosocial/internal/email"
"github.com/superseriousbusiness/gotosocial/internal/federation" "github.com/superseriousbusiness/gotosocial/internal/federation"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/media" "github.com/superseriousbusiness/gotosocial/internal/media"
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/processing" "github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage" "github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/typeutils" "github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/testrig" "github.com/superseriousbusiness/gotosocial/testrig"
@ -44,6 +43,7 @@ type StatusStandardTestSuite struct {
emailSender email.Sender emailSender email.Sender
processor *processing.Processor processor *processing.Processor
storage *storage.Driver storage *storage.Driver
state state.State
// standard suite models // standard suite models
testTokens map[string]*gtsmodel.Token testTokens map[string]*gtsmodel.Token
@ -71,22 +71,26 @@ func (suite *StatusStandardTestSuite) SetupSuite() {
} }
func (suite *StatusStandardTestSuite) SetupTest() { func (suite *StatusStandardTestSuite) SetupTest() {
suite.state.Caches.Init()
testrig.StartWorkers(&suite.state)
testrig.InitTestConfig() testrig.InitTestConfig()
testrig.InitTestLog() testrig.InitTestLog()
suite.db = testrig.NewTestDB() suite.db = testrig.NewTestDB(&suite.state)
suite.tc = testrig.NewTestTypeConverter(suite.db) suite.state.DB = suite.db
suite.storage = testrig.NewInMemoryStorage() suite.storage = testrig.NewInMemoryStorage()
suite.state.Storage = suite.storage
suite.tc = testrig.NewTestTypeConverter(suite.db)
testrig.StandardDBSetup(suite.db, nil) testrig.StandardDBSetup(suite.db, nil)
testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media") testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media")
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) suite.mediaManager = testrig.NewTestMediaManager(&suite.state)
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager)
suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage)
suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker)
suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil) suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil)
suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker) suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager)
suite.statusModule = statuses.New(suite.processor) suite.statusModule = statuses.New(suite.processor)
suite.NoError(suite.processor.Start()) suite.NoError(suite.processor.Start())
@ -95,4 +99,5 @@ func (suite *StatusStandardTestSuite) SetupTest() {
func (suite *StatusStandardTestSuite) TearDownTest() { func (suite *StatusStandardTestSuite) TearDownTest() {
testrig.StandardDBTeardown(suite.db) testrig.StandardDBTeardown(suite.db)
testrig.StandardStorageTeardown(suite.storage) testrig.StandardStorageTeardown(suite.storage)
testrig.StopWorkers(&suite.state)
} }

View file

@ -32,15 +32,14 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/api/client/streaming" "github.com/superseriousbusiness/gotosocial/internal/api/client/streaming"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email" "github.com/superseriousbusiness/gotosocial/internal/email"
"github.com/superseriousbusiness/gotosocial/internal/federation" "github.com/superseriousbusiness/gotosocial/internal/federation"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/media" "github.com/superseriousbusiness/gotosocial/internal/media"
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/processing" "github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage" "github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/typeutils" "github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/testrig" "github.com/superseriousbusiness/gotosocial/testrig"
@ -56,6 +55,7 @@ type StreamingTestSuite struct {
emailSender email.Sender emailSender email.Sender
processor *processing.Processor processor *processing.Processor
storage *storage.Driver storage *storage.Driver
state state.State
// standard suite models // standard suite models
testTokens map[string]*gtsmodel.Token testTokens map[string]*gtsmodel.Token
@ -83,22 +83,25 @@ func (suite *StreamingTestSuite) SetupSuite() {
} }
func (suite *StreamingTestSuite) SetupTest() { func (suite *StreamingTestSuite) SetupTest() {
suite.state.Caches.Init()
testrig.StartWorkers(&suite.state)
testrig.InitTestConfig() testrig.InitTestConfig()
testrig.InitTestLog() testrig.InitTestLog()
suite.db = testrig.NewTestDB() suite.db = testrig.NewTestDB(&suite.state)
suite.tc = testrig.NewTestTypeConverter(suite.db) suite.state.DB = suite.db
suite.storage = testrig.NewInMemoryStorage() suite.storage = testrig.NewInMemoryStorage()
suite.state.Storage = suite.storage
suite.tc = testrig.NewTestTypeConverter(suite.db)
testrig.StandardDBSetup(suite.db, nil) testrig.StandardDBSetup(suite.db, nil)
testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media") testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media")
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) suite.mediaManager = testrig.NewTestMediaManager(&suite.state)
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager)
suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage)
suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker)
suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil) suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil)
suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker) suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager)
suite.streamingModule = streaming.New(suite.processor, 1, 4096) suite.streamingModule = streaming.New(suite.processor, 1, 4096)
suite.NoError(suite.processor.Start()) suite.NoError(suite.processor.Start())
} }
@ -106,6 +109,7 @@ func (suite *StreamingTestSuite) SetupTest() {
func (suite *StreamingTestSuite) TearDownTest() { func (suite *StreamingTestSuite) TearDownTest() {
testrig.StandardDBTeardown(suite.db) testrig.StandardDBTeardown(suite.db)
testrig.StandardStorageTeardown(suite.storage) testrig.StandardStorageTeardown(suite.storage)
testrig.StopWorkers(&suite.state)
} }
// Addr is a fake network interface which implements the net.Addr interface // Addr is a fake network interface which implements the net.Addr interface

View file

@ -21,14 +21,13 @@ package user_test
import ( import (
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/api/client/user" "github.com/superseriousbusiness/gotosocial/internal/api/client/user"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email" "github.com/superseriousbusiness/gotosocial/internal/email"
"github.com/superseriousbusiness/gotosocial/internal/federation" "github.com/superseriousbusiness/gotosocial/internal/federation"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/media" "github.com/superseriousbusiness/gotosocial/internal/media"
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/processing" "github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage" "github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/typeutils" "github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/testrig" "github.com/superseriousbusiness/gotosocial/testrig"
@ -43,6 +42,7 @@ type UserStandardTestSuite struct {
emailSender email.Sender emailSender email.Sender
processor *processing.Processor processor *processing.Processor
storage *storage.Driver storage *storage.Driver
state state.State
testTokens map[string]*gtsmodel.Token testTokens map[string]*gtsmodel.Token
testClients map[string]*gtsmodel.Client testClients map[string]*gtsmodel.Client
@ -56,23 +56,29 @@ type UserStandardTestSuite struct {
} }
func (suite *UserStandardTestSuite) SetupTest() { func (suite *UserStandardTestSuite) SetupTest() {
suite.state.Caches.Init()
testrig.StartWorkers(&suite.state)
testrig.InitTestConfig() testrig.InitTestConfig()
testrig.InitTestLog() testrig.InitTestLog()
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
suite.testTokens = testrig.NewTestTokens() suite.testTokens = testrig.NewTestTokens()
suite.testClients = testrig.NewTestClients() suite.testClients = testrig.NewTestClients()
suite.testApplications = testrig.NewTestApplications() suite.testApplications = testrig.NewTestApplications()
suite.testUsers = testrig.NewTestUsers() suite.testUsers = testrig.NewTestUsers()
suite.testAccounts = testrig.NewTestAccounts() suite.testAccounts = testrig.NewTestAccounts()
suite.db = testrig.NewTestDB()
suite.db = testrig.NewTestDB(&suite.state)
suite.state.DB = suite.db
suite.storage = testrig.NewInMemoryStorage() suite.storage = testrig.NewInMemoryStorage()
suite.state.Storage = suite.storage
suite.tc = testrig.NewTestTypeConverter(suite.db) suite.tc = testrig.NewTestTypeConverter(suite.db)
suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage) suite.mediaManager = testrig.NewTestMediaManager(&suite.state)
suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker) suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager)
suite.sentEmails = make(map[string]string) suite.sentEmails = make(map[string]string)
suite.emailSender = testrig.NewEmailSender("../../../../web/template/", suite.sentEmails) suite.emailSender = testrig.NewEmailSender("../../../../web/template/", suite.sentEmails)
suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker) suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager)
suite.userModule = user.New(suite.processor) suite.userModule = user.New(suite.processor)
testrig.StandardDBSetup(suite.db, suite.testAccounts) testrig.StandardDBSetup(suite.db, suite.testAccounts)
testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media") testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media")
@ -83,4 +89,5 @@ func (suite *UserStandardTestSuite) SetupTest() {
func (suite *UserStandardTestSuite) TearDownTest() { func (suite *UserStandardTestSuite) TearDownTest() {
testrig.StandardDBTeardown(suite.db) testrig.StandardDBTeardown(suite.db)
testrig.StandardStorageTeardown(suite.storage) testrig.StandardStorageTeardown(suite.storage)
testrig.StopWorkers(&suite.state)
} }

View file

@ -23,16 +23,15 @@ import (
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/api/fileserver" "github.com/superseriousbusiness/gotosocial/internal/api/fileserver"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email" "github.com/superseriousbusiness/gotosocial/internal/email"
"github.com/superseriousbusiness/gotosocial/internal/federation" "github.com/superseriousbusiness/gotosocial/internal/federation"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/media" "github.com/superseriousbusiness/gotosocial/internal/media"
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/processing" "github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage" "github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/typeutils" "github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/testrig" "github.com/superseriousbusiness/gotosocial/testrig"
@ -43,6 +42,7 @@ type FileserverTestSuite struct {
suite.Suite suite.Suite
db db.DB db db.DB
storage *storage.Driver storage *storage.Driver
state state.State
federator federation.Federator federator federation.Federator
tc typeutils.TypeConverter tc typeutils.TypeConverter
processor *processing.Processor processor *processing.Processor
@ -67,26 +67,32 @@ type FileserverTestSuite struct {
*/ */
func (suite *FileserverTestSuite) SetupSuite() { func (suite *FileserverTestSuite) SetupSuite() {
testrig.StartWorkers(&suite.state)
testrig.InitTestConfig() testrig.InitTestConfig()
testrig.InitTestLog() testrig.InitTestLog()
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) suite.db = testrig.NewTestDB(&suite.state)
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) suite.state.DB = suite.db
suite.db = testrig.NewTestDB()
suite.storage = testrig.NewInMemoryStorage() suite.storage = testrig.NewInMemoryStorage()
suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker) suite.state.Storage = suite.storage
suite.emailSender = testrig.NewEmailSender("../../../web/template/", nil)
suite.mediaManager = testrig.NewTestMediaManager(&suite.state)
suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../testrig/media")), suite.mediaManager)
suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager)
suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, testrig.NewTestMediaManager(suite.db, suite.storage), clientWorker, fedWorker)
suite.tc = testrig.NewTestTypeConverter(suite.db) suite.tc = testrig.NewTestTypeConverter(suite.db)
suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage) suite.mediaManager = testrig.NewTestMediaManager(&suite.state)
suite.oauthServer = testrig.NewTestOauthServer(suite.db) suite.oauthServer = testrig.NewTestOauthServer(suite.db)
suite.emailSender = testrig.NewEmailSender("../../../web/template/", nil)
suite.fileServer = fileserver.New(suite.processor) suite.fileServer = fileserver.New(suite.processor)
} }
func (suite *FileserverTestSuite) SetupTest() { func (suite *FileserverTestSuite) SetupTest() {
suite.state.Caches.Init()
testrig.StartWorkers(&suite.state)
testrig.StandardDBSetup(suite.db, nil) testrig.StandardDBSetup(suite.db, nil)
testrig.StandardStorageSetup(suite.storage, "../../../testrig/media") testrig.StandardStorageSetup(suite.storage, "../../../testrig/media")
suite.testTokens = testrig.NewTestTokens() suite.testTokens = testrig.NewTestTokens()
@ -101,9 +107,11 @@ func (suite *FileserverTestSuite) TearDownSuite() {
if err := suite.db.Stop(context.Background()); err != nil { if err := suite.db.Stop(context.Background()); err != nil {
log.Panicf(nil, "error closing db connection: %s", err) log.Panicf(nil, "error closing db connection: %s", err)
} }
testrig.StopWorkers(&suite.state)
} }
func (suite *FileserverTestSuite) TearDownTest() { func (suite *FileserverTestSuite) TearDownTest() {
testrig.StandardDBTeardown(suite.db) testrig.StandardDBTeardown(suite.db)
testrig.StandardStorageTeardown(suite.storage) testrig.StandardStorageTeardown(suite.storage)
testrig.StopWorkers(&suite.state)
} }

View file

@ -26,15 +26,14 @@ import (
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/ap" "github.com/superseriousbusiness/gotosocial/internal/ap"
"github.com/superseriousbusiness/gotosocial/internal/api/wellknown/webfinger" "github.com/superseriousbusiness/gotosocial/internal/api/wellknown/webfinger"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email" "github.com/superseriousbusiness/gotosocial/internal/email"
"github.com/superseriousbusiness/gotosocial/internal/federation" "github.com/superseriousbusiness/gotosocial/internal/federation"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/media" "github.com/superseriousbusiness/gotosocial/internal/media"
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/processing" "github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage" "github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/typeutils" "github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/testrig" "github.com/superseriousbusiness/gotosocial/testrig"
@ -44,6 +43,7 @@ type WebfingerStandardTestSuite struct {
// standard suite interfaces // standard suite interfaces
suite.Suite suite.Suite
db db.DB db db.DB
state state.State
tc typeutils.TypeConverter tc typeutils.TypeConverter
mediaManager media.Manager mediaManager media.Manager
federator federation.Federator federator federation.Federator
@ -76,19 +76,21 @@ func (suite *WebfingerStandardTestSuite) SetupSuite() {
} }
func (suite *WebfingerStandardTestSuite) SetupTest() { func (suite *WebfingerStandardTestSuite) SetupTest() {
suite.state.Caches.Init()
testrig.StartWorkers(&suite.state)
testrig.InitTestLog() testrig.InitTestLog()
testrig.InitTestConfig() testrig.InitTestConfig()
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) suite.db = testrig.NewTestDB(&suite.state)
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) suite.state.DB = suite.db
suite.db = testrig.NewTestDB()
suite.tc = testrig.NewTestTypeConverter(suite.db) suite.tc = testrig.NewTestTypeConverter(suite.db)
suite.storage = testrig.NewInMemoryStorage() suite.storage = testrig.NewInMemoryStorage()
suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage) suite.state.Storage = suite.storage
suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker) suite.mediaManager = testrig.NewTestMediaManager(&suite.state)
suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager)
suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil) suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil)
suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker) suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager)
suite.webfingerModule = webfinger.New(suite.processor) suite.webfingerModule = webfinger.New(suite.processor)
suite.oauthServer = testrig.NewTestOauthServer(suite.db) suite.oauthServer = testrig.NewTestOauthServer(suite.db)
testrig.StandardDBSetup(suite.db, suite.testAccounts) testrig.StandardDBSetup(suite.db, suite.testAccounts)
@ -100,6 +102,7 @@ func (suite *WebfingerStandardTestSuite) SetupTest() {
func (suite *WebfingerStandardTestSuite) TearDownTest() { func (suite *WebfingerStandardTestSuite) TearDownTest() {
testrig.StandardDBTeardown(suite.db) testrig.StandardDBTeardown(suite.db)
testrig.StandardStorageTeardown(suite.storage) testrig.StandardStorageTeardown(suite.storage)
testrig.StopWorkers(&suite.state)
} }
func accountDomainAccount() *gtsmodel.Account { func accountDomainAccount() *gtsmodel.Account {

View file

@ -30,9 +30,7 @@ import (
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/api/wellknown/webfinger" "github.com/superseriousbusiness/gotosocial/internal/api/wellknown/webfinger"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/processing" "github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/testrig" "github.com/superseriousbusiness/gotosocial/testrig"
) )
@ -91,9 +89,7 @@ func (suite *WebfingerGetTestSuite) TestFingerUserWithDifferentAccountDomainByHo
config.SetHost("gts.example.org") config.SetHost("gts.example.org")
config.SetAccountDomain("example.org") config.SetAccountDomain("example.org")
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) suite.processor = processing.NewProcessor(suite.tc, suite.federator, testrig.NewTestOauthServer(suite.db), testrig.NewTestMediaManager(&suite.state), &suite.state, suite.emailSender)
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
suite.processor = processing.NewProcessor(suite.tc, suite.federator, testrig.NewTestOauthServer(suite.db), testrig.NewTestMediaManager(suite.db, suite.storage), suite.storage, suite.db, suite.emailSender, clientWorker, fedWorker)
suite.webfingerModule = webfinger.New(suite.processor) suite.webfingerModule = webfinger.New(suite.processor)
targetAccount := accountDomainAccount() targetAccount := accountDomainAccount()
@ -148,9 +144,7 @@ func (suite *WebfingerGetTestSuite) TestFingerUserWithDifferentAccountDomainByAc
config.SetHost("gts.example.org") config.SetHost("gts.example.org")
config.SetAccountDomain("example.org") config.SetAccountDomain("example.org")
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) suite.processor = processing.NewProcessor(suite.tc, suite.federator, testrig.NewTestOauthServer(suite.db), testrig.NewTestMediaManager(&suite.state), &suite.state, suite.emailSender)
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
suite.processor = processing.NewProcessor(suite.tc, suite.federator, testrig.NewTestOauthServer(suite.db), testrig.NewTestMediaManager(suite.db, suite.storage), suite.storage, suite.db, suite.emailSender, clientWorker, fedWorker)
suite.webfingerModule = webfinger.New(suite.processor) suite.webfingerModule = webfinger.New(suite.processor)
targetAccount := accountDomainAccount() targetAccount := accountDomainAccount()

View file

@ -1,141 +0,0 @@
/*
GoToSocial
Copyright (C) 2021-2023 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package concurrency
import (
"context"
"errors"
"fmt"
"path"
"reflect"
"runtime"
"codeberg.org/gruf/go-kv"
"codeberg.org/gruf/go-runners"
"github.com/superseriousbusiness/gotosocial/internal/log"
)
// WorkerPool represents a proccessor for MsgType objects, using a worker pool to allocate resources.
type WorkerPool[MsgType any] struct {
workers runners.WorkerPool
process func(context.Context, MsgType) error
nw, nq int
wtype string // contains worker type for logging
}
// New returns a new WorkerPool[MsgType] with given number of workers and queue ratio,
// where the queue ratio is multiplied by no. workers to get queue size. If args < 1
// then suitable defaults are determined from the runtime's GOMAXPROCS variable.
func NewWorkerPool[MsgType any](workers int, queueRatio int) *WorkerPool[MsgType] {
var zero MsgType
if workers < 1 {
// ensure sensible workers
workers = runtime.GOMAXPROCS(0) * 4
}
if queueRatio < 1 {
// ensure sensible ratio
queueRatio = 100
}
// Calculate the short type string for the msg type
msgType := reflect.TypeOf(zero).String()
_, msgType = path.Split(msgType)
w := &WorkerPool[MsgType]{
process: nil,
nw: workers,
nq: workers * queueRatio,
wtype: fmt.Sprintf("worker.Worker[%s]", msgType),
}
// Log new worker creation with worker type prefix
log.Infof(nil, "%s created with workers=%d queue=%d",
w.wtype,
workers,
workers*queueRatio,
)
return w
}
// Start will attempt to start the underlying worker pool, or return error.
func (w *WorkerPool[MsgType]) Start() error {
log.Infof(nil, "%s starting", w.wtype)
// Check processor was set
if w.process == nil {
return errors.New("nil Worker.process function")
}
// Attempt to start pool
if !w.workers.Start(w.nw, w.nq) {
return errors.New("failed to start Worker pool")
}
return nil
}
// Stop will attempt to stop the underlying worker pool, or return error.
func (w *WorkerPool[MsgType]) Stop() error {
log.Infof(nil, "%s stopping", w.wtype)
// Attempt to stop pool
if !w.workers.Stop() {
return errors.New("failed to stop Worker pool")
}
return nil
}
// SetProcessor will set the Worker's processor function, which is called for each queued message.
func (w *WorkerPool[MsgType]) SetProcessor(fn func(context.Context, MsgType) error) {
if w.process != nil {
log.Panicf(nil, "%s Worker.process is already set", w.wtype)
}
w.process = fn
}
// Queue will queue provided message to be processed with there's a free worker.
func (w *WorkerPool[MsgType]) Queue(msg MsgType) {
log.Tracef(nil, "%s queueing message: %+v", w.wtype, msg)
// Create new process function for msg
process := func(ctx context.Context) {
if err := w.process(ctx, msg); err != nil {
log.WithContext(ctx).
WithFields(kv.Fields{
kv.Field{K: "type", V: w.wtype},
kv.Field{K: "error", V: err},
}...).Error("message processing error")
}
}
// Attempt a fast-enqueue of process
if !w.workers.EnqueueNow(process) {
// No spot acquired, log warning
log.WithFields(kv.Fields{
kv.Field{K: "type", V: w.wtype},
kv.Field{K: "queue", V: w.workers.Queue()},
}...).Warn("full worker queue")
// Block on enqueuing process func
w.workers.Enqueue(process)
}
}

View file

@ -70,8 +70,8 @@ func (suite *AdminTestSuite) TestIsEmailAvailableDomainBlocked() {
} }
func (suite *AdminTestSuite) TestCreateInstanceAccount() { func (suite *AdminTestSuite) TestCreateInstanceAccount() {
// reinitialize test DB to clear caches // reinitialize db caches to clear
suite.db = testrig.NewTestDB() suite.state.Caches.Init()
// we need to take an empty db for this... // we need to take an empty db for this...
testrig.StandardDBTeardown(suite.db) testrig.StandardDBTeardown(suite.db)
// ...with tables created but no data // ...with tables created but no data

View file

@ -22,13 +22,15 @@ import (
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/testrig" "github.com/superseriousbusiness/gotosocial/testrig"
) )
type BunDBStandardTestSuite struct { type BunDBStandardTestSuite struct {
// standard suite interfaces // standard suite interfaces
suite.Suite suite.Suite
db db.DB db db.DB
state state.State
// standard suite models // standard suite models
testTokens map[string]*gtsmodel.Token testTokens map[string]*gtsmodel.Token
@ -61,9 +63,10 @@ func (suite *BunDBStandardTestSuite) SetupSuite() {
} }
func (suite *BunDBStandardTestSuite) SetupTest() { func (suite *BunDBStandardTestSuite) SetupTest() {
suite.state.Caches.Init()
testrig.InitTestConfig() testrig.InitTestConfig()
testrig.InitTestLog() testrig.InitTestLog()
suite.db = testrig.NewTestDB() suite.db = testrig.NewTestDB(&suite.state)
testrig.StandardDBSetup(suite.db, suite.testAccounts) testrig.StandardDBSetup(suite.db, suite.testAccounts)
} }

View file

@ -21,11 +21,10 @@ package dereferencing_test
import ( import (
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/activity/streams/vocab" "github.com/superseriousbusiness/activity/streams/vocab"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/federation/dereferencing" "github.com/superseriousbusiness/gotosocial/internal/federation/dereferencing"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/messages" "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage" "github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/testrig" "github.com/superseriousbusiness/gotosocial/testrig"
) )
@ -34,6 +33,7 @@ type DereferencerStandardTestSuite struct {
suite.Suite suite.Suite
db db.DB db db.DB
storage *storage.Driver storage *storage.Driver
state state.State
testRemoteStatuses map[string]vocab.ActivityStreamsNote testRemoteStatuses map[string]vocab.ActivityStreamsNote
testRemotePeople map[string]vocab.ActivityStreamsPerson testRemotePeople map[string]vocab.ActivityStreamsPerson
@ -58,12 +58,19 @@ func (suite *DereferencerStandardTestSuite) SetupTest() {
suite.testRemoteAttachments = testrig.NewTestFediAttachments("../../../testrig/media") suite.testRemoteAttachments = testrig.NewTestFediAttachments("../../../testrig/media")
suite.testEmojis = testrig.NewTestEmojis() suite.testEmojis = testrig.NewTestEmojis()
suite.db = testrig.NewTestDB() suite.state.Caches.Init()
testrig.StartWorkers(&suite.state)
suite.db = testrig.NewTestDB(&suite.state)
suite.storage = testrig.NewInMemoryStorage() suite.storage = testrig.NewInMemoryStorage()
suite.dereferencer = dereferencing.NewDereferencer(suite.db, testrig.NewTestTypeConverter(suite.db), testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../testrig/media"), suite.db, concurrency.NewWorkerPool[messages.FromFederator](-1, -1)), testrig.NewTestMediaManager(suite.db, suite.storage)) suite.state.DB = suite.db
suite.state.Storage = suite.storage
media := testrig.NewTestMediaManager(&suite.state)
suite.dereferencer = dereferencing.NewDereferencer(suite.db, testrig.NewTestTypeConverter(suite.db), testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../testrig/media")), media)
testrig.StandardDBSetup(suite.db, nil) testrig.StandardDBSetup(suite.db, nil)
} }
func (suite *DereferencerStandardTestSuite) TearDownTest() { func (suite *DereferencerStandardTestSuite) TearDownTest() {
testrig.StandardDBTeardown(suite.db) testrig.StandardDBTeardown(suite.db)
testrig.StopWorkers(&suite.state)
} }

View file

@ -27,10 +27,8 @@ import (
"time" "time"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/federation" "github.com/superseriousbusiness/gotosocial/internal/federation"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/testrig" "github.com/superseriousbusiness/gotosocial/testrig"
) )
@ -56,14 +54,12 @@ func (suite *FederatingActorTestSuite) TestSendNoRemoteFollowers() {
) )
testActivity := testrig.WrapAPNoteInCreate(testrig.URLMustParse("http://localhost:8080/whatever_some_create"), testrig.URLMustParse(testAccount.URI), time.Now(), testNote) testActivity := testrig.WrapAPNoteInCreate(testrig.URLMustParse("http://localhost:8080/whatever_some_create"), testrig.URLMustParse(testAccount.URI), time.Now(), testNote)
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
// setup transport controller with a no-op client so we don't make external calls // setup transport controller with a no-op client so we don't make external calls
httpClient := testrig.NewMockHTTPClient(nil, "../../testrig/media") httpClient := testrig.NewMockHTTPClient(nil, "../../testrig/media")
tc := testrig.NewTestTransportController(httpClient, suite.db, fedWorker) tc := testrig.NewTestTransportController(&suite.state, httpClient)
// setup module being tested // setup module being tested
federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(suite.db, fedWorker), tc, suite.tc, testrig.NewTestMediaManager(suite.db, suite.storage)) federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(&suite.state), tc, suite.tc, testrig.NewTestMediaManager(&suite.state))
activity, err := federator.FederatingActor().Send(ctx, testrig.URLMustParse(testAccount.OutboxURI), testActivity) activity, err := federator.FederatingActor().Send(ctx, testrig.URLMustParse(testAccount.OutboxURI), testActivity)
suite.NoError(err) suite.NoError(err)
@ -105,12 +101,10 @@ func (suite *FederatingActorTestSuite) TestSendRemoteFollower() {
) )
testActivity := testrig.WrapAPNoteInCreate(testrig.URLMustParse("http://localhost:8080/whatever_some_create"), testrig.URLMustParse(testAccount.URI), testrig.TimeMustParse("2022-06-02T12:22:21+02:00"), testNote) testActivity := testrig.WrapAPNoteInCreate(testrig.URLMustParse("http://localhost:8080/whatever_some_create"), testrig.URLMustParse(testAccount.URI), testrig.TimeMustParse("2022-06-02T12:22:21+02:00"), testNote)
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
httpClient := testrig.NewMockHTTPClient(nil, "../../testrig/media") httpClient := testrig.NewMockHTTPClient(nil, "../../testrig/media")
tc := testrig.NewTestTransportController(httpClient, suite.db, fedWorker) tc := testrig.NewTestTransportController(&suite.state, httpClient)
// setup module being tested // setup module being tested
federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(suite.db, fedWorker), tc, suite.tc, testrig.NewTestMediaManager(suite.db, suite.storage)) federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(&suite.state), tc, suite.tc, testrig.NewTestMediaManager(&suite.state))
activity, err := federator.FederatingActor().Send(ctx, testrig.URLMustParse(testAccount.OutboxURI), testActivity) activity, err := federator.FederatingActor().Send(ctx, testrig.URLMustParse(testAccount.OutboxURI), testActivity)
suite.NoError(err) suite.NoError(err)

View file

@ -65,7 +65,7 @@ func (f *federatingDB) Accept(ctx context.Context, accept vocab.ActivityStreamsA
if uris.IsFollowPath(acceptedObjectIRI) { if uris.IsFollowPath(acceptedObjectIRI) {
// ACCEPT FOLLOW // ACCEPT FOLLOW
gtsFollowRequest := &gtsmodel.FollowRequest{} gtsFollowRequest := &gtsmodel.FollowRequest{}
if err := f.db.GetWhere(ctx, []db.Where{{Key: "uri", Value: acceptedObjectIRI.String()}}, gtsFollowRequest); err != nil { if err := f.state.DB.GetWhere(ctx, []db.Where{{Key: "uri", Value: acceptedObjectIRI.String()}}, gtsFollowRequest); err != nil {
return fmt.Errorf("ACCEPT: couldn't get follow request with id %s from the database: %s", acceptedObjectIRI.String(), err) return fmt.Errorf("ACCEPT: couldn't get follow request with id %s from the database: %s", acceptedObjectIRI.String(), err)
} }
@ -73,12 +73,12 @@ func (f *federatingDB) Accept(ctx context.Context, accept vocab.ActivityStreamsA
if gtsFollowRequest.AccountID != receivingAccount.ID { if gtsFollowRequest.AccountID != receivingAccount.ID {
return errors.New("ACCEPT: follow object account and inbox account were not the same") return errors.New("ACCEPT: follow object account and inbox account were not the same")
} }
follow, err := f.db.AcceptFollowRequest(ctx, gtsFollowRequest.AccountID, gtsFollowRequest.TargetAccountID) follow, err := f.state.DB.AcceptFollowRequest(ctx, gtsFollowRequest.AccountID, gtsFollowRequest.TargetAccountID)
if err != nil { if err != nil {
return err return err
} }
f.fedWorker.Queue(messages.FromFederator{ f.state.Workers.EnqueueFederator(ctx, messages.FromFederator{
APObjectType: ap.ActivityFollow, APObjectType: ap.ActivityFollow,
APActivityType: ap.ActivityAccept, APActivityType: ap.ActivityAccept,
GTSModel: follow, GTSModel: follow,
@ -108,12 +108,12 @@ func (f *federatingDB) Accept(ctx context.Context, accept vocab.ActivityStreamsA
if gtsFollow.AccountID != receivingAccount.ID { if gtsFollow.AccountID != receivingAccount.ID {
return errors.New("ACCEPT: follow object account and inbox account were not the same") return errors.New("ACCEPT: follow object account and inbox account were not the same")
} }
follow, err := f.db.AcceptFollowRequest(ctx, gtsFollow.AccountID, gtsFollow.TargetAccountID) follow, err := f.state.DB.AcceptFollowRequest(ctx, gtsFollow.AccountID, gtsFollow.TargetAccountID)
if err != nil { if err != nil {
return err return err
} }
f.fedWorker.Queue(messages.FromFederator{ f.state.Workers.EnqueueFederator(ctx, messages.FromFederator{
APObjectType: ap.ActivityFollow, APObjectType: ap.ActivityFollow,
APActivityType: ap.ActivityAccept, APActivityType: ap.ActivityAccept,
GTSModel: follow, GTSModel: follow,

View file

@ -59,7 +59,7 @@ func (f *federatingDB) Announce(ctx context.Context, announce vocab.ActivityStre
} }
// it's a new announce so pass it back to the processor async for dereferencing etc // it's a new announce so pass it back to the processor async for dereferencing etc
f.fedWorker.Queue(messages.FromFederator{ f.state.Workers.EnqueueFederator(ctx, messages.FromFederator{
APObjectType: ap.ActivityAnnounce, APObjectType: ap.ActivityAnnounce,
APActivityType: ap.ActivityCreate, APActivityType: ap.ActivityCreate,
GTSModel: boost, GTSModel: boost,

View file

@ -25,6 +25,7 @@ import (
"github.com/superseriousbusiness/activity/streams/vocab" "github.com/superseriousbusiness/activity/streams/vocab"
"github.com/superseriousbusiness/gotosocial/internal/ap" "github.com/superseriousbusiness/gotosocial/internal/ap"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/id"
) )
type AnnounceTestSuite struct { type AnnounceTestSuite struct {
@ -74,6 +75,13 @@ func (suite *AnnounceTestSuite) TestAnnounceTwice() {
suite.True(ok) suite.True(ok)
suite.Equal(announcingAccount.ID, boost.AccountID) suite.Equal(announcingAccount.ID, boost.AccountID)
// Insert the boost-of status into the
// DB cache to emulate processor handling
boost.ID, _ = id.NewULIDFromTime(boost.CreatedAt)
suite.state.Caches.GTS.Status().Store(boost, func() error {
return nil
})
// only the URI will be set on the boosted status because it still needs to be dereferenced // only the URI will be set on the boosted status because it still needs to be dereferenced
suite.NotEmpty(boost.BoostOf.URI) suite.NotEmpty(boost.BoostOf.URI)

View file

@ -103,11 +103,11 @@ func (f *federatingDB) activityBlock(ctx context.Context, asType vocab.Type, rec
block.ID = id.NewULID() block.ID = id.NewULID()
if err := f.db.PutBlock(ctx, block); err != nil { if err := f.state.DB.PutBlock(ctx, block); err != nil {
return fmt.Errorf("activityBlock: database error inserting block: %s", err) return fmt.Errorf("activityBlock: database error inserting block: %s", err)
} }
f.fedWorker.Queue(messages.FromFederator{ f.state.Workers.EnqueueFederator(ctx, messages.FromFederator{
APObjectType: ap.ActivityBlock, APObjectType: ap.ActivityBlock,
APActivityType: ap.ActivityCreate, APActivityType: ap.ActivityCreate,
GTSModel: block, GTSModel: block,
@ -202,7 +202,7 @@ func (f *federatingDB) createNote(ctx context.Context, note vocab.ActivityStream
return nil return nil
} }
// pass the note iri into the processor and have it do the dereferencing instead of doing it here // pass the note iri into the processor and have it do the dereferencing instead of doing it here
f.fedWorker.Queue(messages.FromFederator{ f.state.Workers.EnqueueFederator(ctx, messages.FromFederator{
APObjectType: ap.ObjectNote, APObjectType: ap.ObjectNote,
APActivityType: ap.ActivityCreate, APActivityType: ap.ActivityCreate,
APIri: id.GetIRI(), APIri: id.GetIRI(),
@ -226,7 +226,7 @@ func (f *federatingDB) createNote(ctx context.Context, note vocab.ActivityStream
} }
status.ID = statusID status.ID = statusID
if err := f.db.PutStatus(ctx, status); err != nil { if err := f.state.DB.PutStatus(ctx, status); err != nil {
if errors.Is(err, db.ErrAlreadyExists) { if errors.Is(err, db.ErrAlreadyExists) {
// the status already exists in the database, which means we've already handled everything else, // the status already exists in the database, which means we've already handled everything else,
// so we can just return nil here and be done with it. // so we can just return nil here and be done with it.
@ -236,7 +236,7 @@ func (f *federatingDB) createNote(ctx context.Context, note vocab.ActivityStream
return fmt.Errorf("createNote: database error inserting status: %s", err) return fmt.Errorf("createNote: database error inserting status: %s", err)
} }
f.fedWorker.Queue(messages.FromFederator{ f.state.Workers.EnqueueFederator(ctx, messages.FromFederator{
APObjectType: ap.ObjectNote, APObjectType: ap.ObjectNote,
APActivityType: ap.ActivityCreate, APActivityType: ap.ActivityCreate,
GTSModel: status, GTSModel: status,
@ -263,11 +263,11 @@ func (f *federatingDB) activityFollow(ctx context.Context, asType vocab.Type, re
followRequest.ID = id.NewULID() followRequest.ID = id.NewULID()
if err := f.db.Put(ctx, followRequest); err != nil { if err := f.state.DB.Put(ctx, followRequest); err != nil {
return fmt.Errorf("activityFollow: database error inserting follow request: %s", err) return fmt.Errorf("activityFollow: database error inserting follow request: %s", err)
} }
f.fedWorker.Queue(messages.FromFederator{ f.state.Workers.EnqueueFederator(ctx, messages.FromFederator{
APObjectType: ap.ActivityFollow, APObjectType: ap.ActivityFollow,
APActivityType: ap.ActivityCreate, APActivityType: ap.ActivityCreate,
GTSModel: followRequest, GTSModel: followRequest,
@ -294,11 +294,11 @@ func (f *federatingDB) activityLike(ctx context.Context, asType vocab.Type, rece
fave.ID = id.NewULID() fave.ID = id.NewULID()
if err := f.db.Put(ctx, fave); err != nil { if err := f.state.DB.Put(ctx, fave); err != nil {
return fmt.Errorf("activityLike: database error inserting fave: %s", err) return fmt.Errorf("activityLike: database error inserting fave: %s", err)
} }
f.fedWorker.Queue(messages.FromFederator{ f.state.Workers.EnqueueFederator(ctx, messages.FromFederator{
APObjectType: ap.ActivityLike, APObjectType: ap.ActivityLike,
APActivityType: ap.ActivityCreate, APActivityType: ap.ActivityCreate,
GTSModel: fave, GTSModel: fave,
@ -325,11 +325,11 @@ func (f *federatingDB) activityFlag(ctx context.Context, asType vocab.Type, rece
report.ID = id.NewULID() report.ID = id.NewULID()
if err := f.db.PutReport(ctx, report); err != nil { if err := f.state.DB.PutReport(ctx, report); err != nil {
return fmt.Errorf("activityFlag: database error inserting report: %w", err) return fmt.Errorf("activityFlag: database error inserting report: %w", err)
} }
f.fedWorker.Queue(messages.FromFederator{ f.state.Workers.EnqueueFederator(ctx, messages.FromFederator{
APObjectType: ap.ActivityFlag, APObjectType: ap.ActivityFlag,
APActivityType: ap.ActivityCreate, APActivityType: ap.ActivityCreate,
GTSModel: report, GTSModel: report,

View file

@ -24,9 +24,7 @@ import (
"codeberg.org/gruf/go-mutexes" "codeberg.org/gruf/go-mutexes"
"github.com/superseriousbusiness/activity/pub" "github.com/superseriousbusiness/activity/pub"
"github.com/superseriousbusiness/activity/streams/vocab" "github.com/superseriousbusiness/activity/streams/vocab"
"github.com/superseriousbusiness/gotosocial/internal/concurrency" "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/typeutils" "github.com/superseriousbusiness/gotosocial/internal/typeutils"
) )
@ -43,17 +41,15 @@ type DB interface {
// It doesn't care what the underlying implementation of the DB interface is, as long as it works. // It doesn't care what the underlying implementation of the DB interface is, as long as it works.
type federatingDB struct { type federatingDB struct {
locks mutexes.MutexMap locks mutexes.MutexMap
db db.DB state *state.State
fedWorker *concurrency.WorkerPool[messages.FromFederator]
typeConverter typeutils.TypeConverter typeConverter typeutils.TypeConverter
} }
// New returns a DB interface using the given database and config // New returns a DB interface using the given database and config
func New(db db.DB, fedWorker *concurrency.WorkerPool[messages.FromFederator], tc typeutils.TypeConverter) DB { func New(state *state.State, tc typeutils.TypeConverter) DB {
fdb := federatingDB{ fdb := federatingDB{
locks: mutexes.NewMap(-1, -1), // use defaults locks: mutexes.NewMap(-1, -1), // use defaults
db: db, state: state,
fedWorker: fedWorker,
typeConverter: tc, typeConverter: tc,
} }
return &fdb return &fdb

View file

@ -51,9 +51,9 @@ func (f *federatingDB) Delete(ctx context.Context, id *url.URL) error {
// in a delete we only get the URI, we can't know if we have a status or a profile or something else, // in a delete we only get the URI, we can't know if we have a status or a profile or something else,
// so we have to try a few different things... // so we have to try a few different things...
if s, err := f.db.GetStatusByURI(ctx, id.String()); err == nil && requestingAccount.ID == s.AccountID { if s, err := f.state.DB.GetStatusByURI(ctx, id.String()); err == nil && requestingAccount.ID == s.AccountID {
l.Debugf("uri is for STATUS with id: %s", s.ID) l.Debugf("uri is for STATUS with id: %s", s.ID)
f.fedWorker.Queue(messages.FromFederator{ f.state.Workers.EnqueueFederator(ctx, messages.FromFederator{
APObjectType: ap.ObjectNote, APObjectType: ap.ObjectNote,
APActivityType: ap.ActivityDelete, APActivityType: ap.ActivityDelete,
GTSModel: s, GTSModel: s,
@ -61,9 +61,9 @@ func (f *federatingDB) Delete(ctx context.Context, id *url.URL) error {
}) })
} }
if a, err := f.db.GetAccountByURI(ctx, id.String()); err == nil && requestingAccount.ID == a.ID { if a, err := f.state.DB.GetAccountByURI(ctx, id.String()); err == nil && requestingAccount.ID == a.ID {
l.Debugf("uri is for ACCOUNT with id %s", a.ID) l.Debugf("uri is for ACCOUNT with id %s", a.ID)
f.fedWorker.Queue(messages.FromFederator{ f.state.Workers.EnqueueFederator(ctx, messages.FromFederator{
APObjectType: ap.ObjectProfile, APObjectType: ap.ObjectProfile,
APActivityType: ap.ActivityDelete, APActivityType: ap.ActivityDelete,
GTSModel: a, GTSModel: a,

View file

@ -23,11 +23,11 @@ import (
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/ap" "github.com/superseriousbusiness/gotosocial/internal/ap"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/federation/federatingdb" "github.com/superseriousbusiness/gotosocial/internal/federation/federatingdb"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/messages" "github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/typeutils" "github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/testrig" "github.com/superseriousbusiness/gotosocial/testrig"
) )
@ -36,9 +36,9 @@ type FederatingDBTestSuite struct {
suite.Suite suite.Suite
db db.DB db db.DB
tc typeutils.TypeConverter tc typeutils.TypeConverter
fedWorker *concurrency.WorkerPool[messages.FromFederator]
fromFederator chan messages.FromFederator fromFederator chan messages.FromFederator
federatingDB federatingdb.DB federatingDB federatingdb.DB
state state.State
testTokens map[string]*gtsmodel.Token testTokens map[string]*gtsmodel.Token
testClients map[string]*gtsmodel.Client testClients map[string]*gtsmodel.Client
@ -66,22 +66,33 @@ func (suite *FederatingDBTestSuite) SetupTest() {
testrig.InitTestConfig() testrig.InitTestConfig()
testrig.InitTestLog() testrig.InitTestLog()
suite.fedWorker = concurrency.NewWorkerPool[messages.FromFederator](-1, -1) suite.state.Caches.Init()
testrig.StartWorkers(&suite.state)
suite.fromFederator = make(chan messages.FromFederator, 10) suite.fromFederator = make(chan messages.FromFederator, 10)
suite.fedWorker.SetProcessor(func(ctx context.Context, msg messages.FromFederator) error { suite.state.Workers.EnqueueFederator = func(ctx context.Context, msg messages.FromFederator) {
suite.fromFederator <- msg suite.fromFederator <- msg
return nil }
})
_ = suite.fedWorker.Start() suite.db = testrig.NewTestDB(&suite.state)
suite.db = testrig.NewTestDB()
suite.testActivities = testrig.NewTestActivities(suite.testAccounts) suite.testActivities = testrig.NewTestActivities(suite.testAccounts)
suite.tc = testrig.NewTestTypeConverter(suite.db) suite.tc = testrig.NewTestTypeConverter(suite.db)
suite.federatingDB = testrig.NewTestFederatingDB(suite.db, suite.fedWorker) suite.federatingDB = testrig.NewTestFederatingDB(&suite.state)
testrig.StandardDBSetup(suite.db, suite.testAccounts) testrig.StandardDBSetup(suite.db, suite.testAccounts)
suite.state.DB = suite.db
} }
func (suite *FederatingDBTestSuite) TearDownTest() { func (suite *FederatingDBTestSuite) TearDownTest() {
testrig.StandardDBTeardown(suite.db) testrig.StandardDBTeardown(suite.db)
testrig.StopWorkers(&suite.state)
for suite.fromFederator != nil {
select {
case <-suite.fromFederator:
default:
return
}
}
} }
func createTestContext(receivingAccount *gtsmodel.Account, requestingAccount *gtsmodel.Account) context.Context { func createTestContext(receivingAccount *gtsmodel.Account, requestingAccount *gtsmodel.Account) context.Context {

View file

@ -29,7 +29,7 @@ func (f *federatingDB) Followers(ctx context.Context, actorIRI *url.URL) (follow
return nil, err return nil, err
} }
acctFollowers, err := f.db.GetAccountFollowedBy(ctx, acct.ID, false) acctFollowers, err := f.state.DB.GetAccountFollowedBy(ctx, acct.ID, false)
if err != nil { if err != nil {
return nil, fmt.Errorf("Followers: db error getting followers for account id %s: %s", acct.ID, err) return nil, fmt.Errorf("Followers: db error getting followers for account id %s: %s", acct.ID, err)
} }
@ -37,7 +37,7 @@ func (f *federatingDB) Followers(ctx context.Context, actorIRI *url.URL) (follow
iris := []*url.URL{} iris := []*url.URL{}
for _, follow := range acctFollowers { for _, follow := range acctFollowers {
if follow.Account == nil { if follow.Account == nil {
a, err := f.db.GetAccountByID(ctx, follow.AccountID) a, err := f.state.DB.GetAccountByID(ctx, follow.AccountID)
if err != nil { if err != nil {
errWrapped := fmt.Errorf("Followers: db error getting account id %s: %s", follow.AccountID, err) errWrapped := fmt.Errorf("Followers: db error getting account id %s: %s", follow.AccountID, err)
if err == db.ErrNoEntries { if err == db.ErrNoEntries {

View file

@ -47,7 +47,7 @@ func (f *federatingDB) Following(ctx context.Context, actorIRI *url.URL) (follow
return nil, err return nil, err
} }
acctFollowing, err := f.db.GetAccountFollows(ctx, acct.ID) acctFollowing, err := f.state.DB.GetAccountFollows(ctx, acct.ID)
if err != nil { if err != nil {
return nil, fmt.Errorf("Following: db error getting following for account id %s: %s", acct.ID, err) return nil, fmt.Errorf("Following: db error getting following for account id %s: %s", acct.ID, err)
} }
@ -55,7 +55,7 @@ func (f *federatingDB) Following(ctx context.Context, actorIRI *url.URL) (follow
iris := []*url.URL{} iris := []*url.URL{}
for _, follow := range acctFollowing { for _, follow := range acctFollowing {
if follow.TargetAccount == nil { if follow.TargetAccount == nil {
a, err := f.db.GetAccountByID(ctx, follow.TargetAccountID) a, err := f.state.DB.GetAccountByID(ctx, follow.TargetAccountID)
if err != nil { if err != nil {
errWrapped := fmt.Errorf("Following: db error getting account id %s: %s", follow.TargetAccountID, err) errWrapped := fmt.Errorf("Following: db error getting account id %s: %s", follow.TargetAccountID, err)
if err == db.ErrNoEntries { if err == db.ErrNoEntries {

View file

@ -39,13 +39,13 @@ func (f *federatingDB) Get(ctx context.Context, id *url.URL) (value vocab.Type,
switch { switch {
case uris.IsUserPath(id): case uris.IsUserPath(id):
acct, err := f.db.GetAccountByURI(ctx, id.String()) acct, err := f.state.DB.GetAccountByURI(ctx, id.String())
if err != nil { if err != nil {
return nil, err return nil, err
} }
return f.typeConverter.AccountToAS(ctx, acct) return f.typeConverter.AccountToAS(ctx, acct)
case uris.IsStatusesPath(id): case uris.IsStatusesPath(id):
status, err := f.db.GetStatusByURI(ctx, id.String()) status, err := f.state.DB.GetStatusByURI(ctx, id.String())
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -85,12 +85,12 @@ func (f *federatingDB) InboxesForIRI(c context.Context, iri *url.URL) (inboxIRIs
return nil, fmt.Errorf("couldn't extract local account username from uri %s: %s", iri, err) return nil, fmt.Errorf("couldn't extract local account username from uri %s: %s", iri, err)
} }
account, err := f.db.GetAccountByUsernameDomain(c, localAccountUsername, "") account, err := f.state.DB.GetAccountByUsernameDomain(c, localAccountUsername, "")
if err != nil { if err != nil {
return nil, fmt.Errorf("couldn't find local account with username %s: %s", localAccountUsername, err) return nil, fmt.Errorf("couldn't find local account with username %s: %s", localAccountUsername, err)
} }
follows, err := f.db.GetAccountFollowedBy(c, account.ID, false) follows, err := f.state.DB.GetAccountFollowedBy(c, account.ID, false)
if err != nil { if err != nil {
return nil, fmt.Errorf("couldn't get followers of local account %s: %s", localAccountUsername, err) return nil, fmt.Errorf("couldn't get followers of local account %s: %s", localAccountUsername, err)
} }
@ -98,7 +98,7 @@ func (f *federatingDB) InboxesForIRI(c context.Context, iri *url.URL) (inboxIRIs
for _, follow := range follows { for _, follow := range follows {
// make sure we retrieved the following account from the db // make sure we retrieved the following account from the db
if follow.Account == nil { if follow.Account == nil {
followingAccount, err := f.db.GetAccountByID(c, follow.AccountID) followingAccount, err := f.state.DB.GetAccountByID(c, follow.AccountID)
if err != nil { if err != nil {
if err == db.ErrNoEntries { if err == db.ErrNoEntries {
continue continue
@ -126,7 +126,7 @@ func (f *federatingDB) InboxesForIRI(c context.Context, iri *url.URL) (inboxIRIs
} }
// check if this is just an account IRI... // check if this is just an account IRI...
if account, err := f.db.GetAccountByURI(c, iri.String()); err == nil { if account, err := f.state.DB.GetAccountByURI(c, iri.String()); err == nil {
// deliver to a shared inbox if we have that option // deliver to a shared inbox if we have that option
var inbox string var inbox string
if config.GetInstanceDeliverToSharedInboxes() && account.SharedInboxURI != nil && *account.SharedInboxURI != "" { if config.GetInstanceDeliverToSharedInboxes() && account.SharedInboxURI != nil && *account.SharedInboxURI != "" {

View file

@ -54,7 +54,7 @@ func (f *federatingDB) Owns(ctx context.Context, id *url.URL) (bool, error) {
if err != nil { if err != nil {
return false, fmt.Errorf("error parsing statuses path for url %s: %s", id.String(), err) return false, fmt.Errorf("error parsing statuses path for url %s: %s", id.String(), err)
} }
status, err := f.db.GetStatusByURI(ctx, uid) status, err := f.state.DB.GetStatusByURI(ctx, uid)
if err != nil { if err != nil {
if err == db.ErrNoEntries { if err == db.ErrNoEntries {
// there are no entries for this status // there are no entries for this status
@ -71,7 +71,7 @@ func (f *federatingDB) Owns(ctx context.Context, id *url.URL) (bool, error) {
if err != nil { if err != nil {
return false, fmt.Errorf("error parsing statuses path for url %s: %s", id.String(), err) return false, fmt.Errorf("error parsing statuses path for url %s: %s", id.String(), err)
} }
if _, err := f.db.GetAccountByUsernameDomain(ctx, username, ""); err != nil { if _, err := f.state.DB.GetAccountByUsernameDomain(ctx, username, ""); err != nil {
if err == db.ErrNoEntries { if err == db.ErrNoEntries {
// there are no entries for this username // there are no entries for this username
return false, nil return false, nil
@ -88,7 +88,7 @@ func (f *federatingDB) Owns(ctx context.Context, id *url.URL) (bool, error) {
if err != nil { if err != nil {
return false, fmt.Errorf("error parsing statuses path for url %s: %s", id.String(), err) return false, fmt.Errorf("error parsing statuses path for url %s: %s", id.String(), err)
} }
if _, err := f.db.GetAccountByUsernameDomain(ctx, username, ""); err != nil { if _, err := f.state.DB.GetAccountByUsernameDomain(ctx, username, ""); err != nil {
if err == db.ErrNoEntries { if err == db.ErrNoEntries {
// there are no entries for this username // there are no entries for this username
return false, nil return false, nil
@ -105,7 +105,7 @@ func (f *federatingDB) Owns(ctx context.Context, id *url.URL) (bool, error) {
if err != nil { if err != nil {
return false, fmt.Errorf("error parsing statuses path for url %s: %s", id.String(), err) return false, fmt.Errorf("error parsing statuses path for url %s: %s", id.String(), err)
} }
if _, err := f.db.GetAccountByUsernameDomain(ctx, username, ""); err != nil { if _, err := f.state.DB.GetAccountByUsernameDomain(ctx, username, ""); err != nil {
if err == db.ErrNoEntries { if err == db.ErrNoEntries {
// there are no entries for this username // there are no entries for this username
return false, nil return false, nil
@ -122,7 +122,7 @@ func (f *federatingDB) Owns(ctx context.Context, id *url.URL) (bool, error) {
if err != nil { if err != nil {
return false, fmt.Errorf("error parsing like path for url %s: %s", id.String(), err) return false, fmt.Errorf("error parsing like path for url %s: %s", id.String(), err)
} }
if _, err := f.db.GetAccountByUsernameDomain(ctx, username, ""); err != nil { if _, err := f.state.DB.GetAccountByUsernameDomain(ctx, username, ""); err != nil {
if err == db.ErrNoEntries { if err == db.ErrNoEntries {
// there are no entries for this username // there are no entries for this username
return false, nil return false, nil
@ -130,7 +130,7 @@ func (f *federatingDB) Owns(ctx context.Context, id *url.URL) (bool, error) {
// an actual error happened // an actual error happened
return false, fmt.Errorf("database error fetching account with username %s: %s", username, err) return false, fmt.Errorf("database error fetching account with username %s: %s", username, err)
} }
if err := f.db.GetByID(ctx, likeID, &gtsmodel.StatusFave{}); err != nil { if err := f.state.DB.GetByID(ctx, likeID, &gtsmodel.StatusFave{}); err != nil {
if err == db.ErrNoEntries { if err == db.ErrNoEntries {
// there are no entries // there are no entries
return false, nil return false, nil
@ -147,7 +147,7 @@ func (f *federatingDB) Owns(ctx context.Context, id *url.URL) (bool, error) {
if err != nil { if err != nil {
return false, fmt.Errorf("error parsing block path for url %s: %s", id.String(), err) return false, fmt.Errorf("error parsing block path for url %s: %s", id.String(), err)
} }
if _, err := f.db.GetAccountByUsernameDomain(ctx, username, ""); err != nil { if _, err := f.state.DB.GetAccountByUsernameDomain(ctx, username, ""); err != nil {
if err == db.ErrNoEntries { if err == db.ErrNoEntries {
// there are no entries for this username // there are no entries for this username
return false, nil return false, nil
@ -155,7 +155,7 @@ func (f *federatingDB) Owns(ctx context.Context, id *url.URL) (bool, error) {
// an actual error happened // an actual error happened
return false, fmt.Errorf("database error fetching account with username %s: %s", username, err) return false, fmt.Errorf("database error fetching account with username %s: %s", username, err)
} }
if err := f.db.GetByID(ctx, blockID, &gtsmodel.Block{}); err != nil { if err := f.state.DB.GetByID(ctx, blockID, &gtsmodel.Block{}); err != nil {
if err == db.ErrNoEntries { if err == db.ErrNoEntries {
// there are no entries // there are no entries
return false, nil return false, nil

View file

@ -64,7 +64,7 @@ func (f *federatingDB) Reject(ctx context.Context, reject vocab.ActivityStreamsR
if uris.IsFollowPath(rejectedObjectIRI) { if uris.IsFollowPath(rejectedObjectIRI) {
// REJECT FOLLOW // REJECT FOLLOW
gtsFollowRequest := &gtsmodel.FollowRequest{} gtsFollowRequest := &gtsmodel.FollowRequest{}
if err := f.db.GetWhere(ctx, []db.Where{{Key: "uri", Value: rejectedObjectIRI.String()}}, gtsFollowRequest); err != nil { if err := f.state.DB.GetWhere(ctx, []db.Where{{Key: "uri", Value: rejectedObjectIRI.String()}}, gtsFollowRequest); err != nil {
return fmt.Errorf("Reject: couldn't get follow request with id %s from the database: %s", rejectedObjectIRI.String(), err) return fmt.Errorf("Reject: couldn't get follow request with id %s from the database: %s", rejectedObjectIRI.String(), err)
} }
@ -73,7 +73,7 @@ func (f *federatingDB) Reject(ctx context.Context, reject vocab.ActivityStreamsR
return errors.New("Reject: follow object account and inbox account were not the same") return errors.New("Reject: follow object account and inbox account were not the same")
} }
if _, err := f.db.RejectFollowRequest(ctx, gtsFollowRequest.AccountID, gtsFollowRequest.TargetAccountID); err != nil { if _, err := f.state.DB.RejectFollowRequest(ctx, gtsFollowRequest.AccountID, gtsFollowRequest.TargetAccountID); err != nil {
return err return err
} }
@ -102,7 +102,7 @@ func (f *federatingDB) Reject(ctx context.Context, reject vocab.ActivityStreamsR
if gtsFollow.AccountID != receivingAccount.ID { if gtsFollow.AccountID != receivingAccount.ID {
return errors.New("Reject: follow object account and inbox account were not the same") return errors.New("Reject: follow object account and inbox account were not the same")
} }
if _, err := f.db.RejectFollowRequest(ctx, gtsFollow.AccountID, gtsFollow.TargetAccountID); err != nil { if _, err := f.state.DB.RejectFollowRequest(ctx, gtsFollow.AccountID, gtsFollow.TargetAccountID); err != nil {
return err return err
} }

View file

@ -81,11 +81,11 @@ func (f *federatingDB) Undo(ctx context.Context, undo vocab.ActivityStreamsUndo)
return errors.New("UNDO: follow object account and inbox account were not the same") return errors.New("UNDO: follow object account and inbox account were not the same")
} }
// delete any existing FOLLOW // delete any existing FOLLOW
if err := f.db.DeleteWhere(ctx, []db.Where{{Key: "uri", Value: gtsFollow.URI}}, &gtsmodel.Follow{}); err != nil { if err := f.state.DB.DeleteWhere(ctx, []db.Where{{Key: "uri", Value: gtsFollow.URI}}, &gtsmodel.Follow{}); err != nil {
return fmt.Errorf("UNDO: db error removing follow: %s", err) return fmt.Errorf("UNDO: db error removing follow: %s", err)
} }
// delete any existing FOLLOW REQUEST // delete any existing FOLLOW REQUEST
if err := f.db.DeleteWhere(ctx, []db.Where{{Key: "uri", Value: gtsFollow.URI}}, &gtsmodel.FollowRequest{}); err != nil { if err := f.state.DB.DeleteWhere(ctx, []db.Where{{Key: "uri", Value: gtsFollow.URI}}, &gtsmodel.FollowRequest{}); err != nil {
return fmt.Errorf("UNDO: db error removing follow request: %s", err) return fmt.Errorf("UNDO: db error removing follow request: %s", err)
} }
l.Debug("follow undone") l.Debug("follow undone")
@ -114,7 +114,7 @@ func (f *federatingDB) Undo(ctx context.Context, undo vocab.ActivityStreamsUndo)
return errors.New("UNDO: block object account and inbox account were not the same") return errors.New("UNDO: block object account and inbox account were not the same")
} }
// delete any existing BLOCK // delete any existing BLOCK
if err := f.db.DeleteBlockByURI(ctx, gtsBlock.URI); err != nil { if err := f.state.DB.DeleteBlockByURI(ctx, gtsBlock.URI); err != nil {
return fmt.Errorf("UNDO: db error removing block: %s", err) return fmt.Errorf("UNDO: db error removing block: %s", err)
} }
l.Debug("block undone") l.Debug("block undone")

View file

@ -138,7 +138,7 @@ func (f *federatingDB) Update(ctx context.Context, asType vocab.Type) error {
// pass to the processor for further updating of eg., avatar/header, emojis // pass to the processor for further updating of eg., avatar/header, emojis
// the actual db insert/update will take place a bit later // the actual db insert/update will take place a bit later
f.fedWorker.Queue(messages.FromFederator{ f.state.Workers.EnqueueFederator(ctx, messages.FromFederator{
APObjectType: ap.ObjectProfile, APObjectType: ap.ObjectProfile,
APActivityType: ap.ActivityUpdate, APActivityType: ap.ActivityUpdate,
GTSModel: updatedAcct, GTSModel: updatedAcct,

View file

@ -95,7 +95,7 @@ func (f *federatingDB) NewID(ctx context.Context, t vocab.Type) (idURL *url.URL,
// take the IRI of the first actor we can find (there should only be one) // take the IRI of the first actor we can find (there should only be one)
if iter.IsIRI() { if iter.IsIRI() {
// if there's an error here, just use the fallback behavior -- we don't need to return an error here // if there's an error here, just use the fallback behavior -- we don't need to return an error here
if actorAccount, err := f.db.GetAccountByURI(ctx, iter.GetIRI().String()); err == nil { if actorAccount, err := f.state.DB.GetAccountByURI(ctx, iter.GetIRI().String()); err == nil {
newID, err := id.NewRandomULID() newID, err := id.NewRandomULID()
if err != nil { if err != nil {
return nil, err return nil, err
@ -238,7 +238,7 @@ func (f *federatingDB) getAccountForIRI(ctx context.Context, iri *url.URL) (*gts
switch { switch {
case uris.IsUserPath(iri): case uris.IsUserPath(iri):
if acct, err = f.db.GetAccountByURI(ctx, iri.String()); err != nil { if acct, err = f.state.DB.GetAccountByURI(ctx, iri.String()); err != nil {
if err == db.ErrNoEntries { if err == db.ErrNoEntries {
return nil, fmt.Errorf("no actor found that corresponds to uri %s", iri.String()) return nil, fmt.Errorf("no actor found that corresponds to uri %s", iri.String())
} }
@ -246,7 +246,7 @@ func (f *federatingDB) getAccountForIRI(ctx context.Context, iri *url.URL) (*gts
} }
return acct, nil return acct, nil
case uris.IsInboxPath(iri): case uris.IsInboxPath(iri):
if err = f.db.GetWhere(ctx, []db.Where{{Key: "inbox_uri", Value: iri.String()}}, acct); err != nil { if err = f.state.DB.GetWhere(ctx, []db.Where{{Key: "inbox_uri", Value: iri.String()}}, acct); err != nil {
if err == db.ErrNoEntries { if err == db.ErrNoEntries {
return nil, fmt.Errorf("no actor found that corresponds to inbox %s", iri.String()) return nil, fmt.Errorf("no actor found that corresponds to inbox %s", iri.String())
} }
@ -254,7 +254,7 @@ func (f *federatingDB) getAccountForIRI(ctx context.Context, iri *url.URL) (*gts
} }
return acct, nil return acct, nil
case uris.IsOutboxPath(iri): case uris.IsOutboxPath(iri):
if err = f.db.GetWhere(ctx, []db.Where{{Key: "outbox_uri", Value: iri.String()}}, acct); err != nil { if err = f.state.DB.GetWhere(ctx, []db.Where{{Key: "outbox_uri", Value: iri.String()}}, acct); err != nil {
if err == db.ErrNoEntries { if err == db.ErrNoEntries {
return nil, fmt.Errorf("no actor found that corresponds to outbox %s", iri.String()) return nil, fmt.Errorf("no actor found that corresponds to outbox %s", iri.String())
} }
@ -262,7 +262,7 @@ func (f *federatingDB) getAccountForIRI(ctx context.Context, iri *url.URL) (*gts
} }
return acct, nil return acct, nil
case uris.IsFollowersPath(iri): case uris.IsFollowersPath(iri):
if err = f.db.GetWhere(ctx, []db.Where{{Key: "followers_uri", Value: iri.String()}}, acct); err != nil { if err = f.state.DB.GetWhere(ctx, []db.Where{{Key: "followers_uri", Value: iri.String()}}, acct); err != nil {
if err == db.ErrNoEntries { if err == db.ErrNoEntries {
return nil, fmt.Errorf("no actor found that corresponds to followers_uri %s", iri.String()) return nil, fmt.Errorf("no actor found that corresponds to followers_uri %s", iri.String())
} }
@ -270,7 +270,7 @@ func (f *federatingDB) getAccountForIRI(ctx context.Context, iri *url.URL) (*gts
} }
return acct, nil return acct, nil
case uris.IsFollowingPath(iri): case uris.IsFollowingPath(iri):
if err = f.db.GetWhere(ctx, []db.Where{{Key: "following_uri", Value: iri.String()}}, acct); err != nil { if err = f.state.DB.GetWhere(ctx, []db.Where{{Key: "following_uri", Value: iri.String()}}, acct); err != nil {
if err == db.ErrNoEntries { if err == db.ErrNoEntries {
return nil, fmt.Errorf("no actor found that corresponds to following_uri %s", iri.String()) return nil, fmt.Errorf("no actor found that corresponds to following_uri %s", iri.String())
} }

View file

@ -28,10 +28,8 @@ import (
"github.com/go-fed/httpsig" "github.com/go-fed/httpsig"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/ap" "github.com/superseriousbusiness/gotosocial/internal/ap"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/federation" "github.com/superseriousbusiness/gotosocial/internal/federation"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/testrig" "github.com/superseriousbusiness/gotosocial/testrig"
) )
@ -43,12 +41,10 @@ func (suite *FederatingProtocolTestSuite) TestPostInboxRequestBodyHook1() {
// the activity we're gonna use // the activity we're gonna use
activity := suite.testActivities["dm_for_zork"] activity := suite.testActivities["dm_for_zork"]
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
httpClient := testrig.NewMockHTTPClient(nil, "../../testrig/media") httpClient := testrig.NewMockHTTPClient(nil, "../../testrig/media")
tc := testrig.NewTestTransportController(httpClient, suite.db, fedWorker) tc := testrig.NewTestTransportController(&suite.state, httpClient)
// setup module being tested // setup module being tested
federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(suite.db, fedWorker), tc, suite.tc, testrig.NewTestMediaManager(suite.db, suite.storage)) federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(&suite.state), tc, suite.tc, testrig.NewTestMediaManager(&suite.state))
// setup request // setup request
ctx := context.Background() ctx := context.Background()
@ -74,13 +70,11 @@ func (suite *FederatingProtocolTestSuite) TestPostInboxRequestBodyHook2() {
// the activity we're gonna use // the activity we're gonna use
activity := suite.testActivities["reply_to_turtle_for_zork"] activity := suite.testActivities["reply_to_turtle_for_zork"]
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
httpClient := testrig.NewMockHTTPClient(nil, "../../testrig/media") httpClient := testrig.NewMockHTTPClient(nil, "../../testrig/media")
tc := testrig.NewTestTransportController(httpClient, suite.db, fedWorker) tc := testrig.NewTestTransportController(&suite.state, httpClient)
// setup module being tested // setup module being tested
federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(suite.db, fedWorker), tc, suite.tc, testrig.NewTestMediaManager(suite.db, suite.storage)) federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(&suite.state), tc, suite.tc, testrig.NewTestMediaManager(&suite.state))
// setup request // setup request
ctx := context.Background() ctx := context.Background()
@ -107,13 +101,11 @@ func (suite *FederatingProtocolTestSuite) TestPostInboxRequestBodyHook3() {
// the activity we're gonna use // the activity we're gonna use
activity := suite.testActivities["reply_to_turtle_for_turtle"] activity := suite.testActivities["reply_to_turtle_for_turtle"]
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
httpClient := testrig.NewMockHTTPClient(nil, "../../testrig/media") httpClient := testrig.NewMockHTTPClient(nil, "../../testrig/media")
tc := testrig.NewTestTransportController(httpClient, suite.db, fedWorker) tc := testrig.NewTestTransportController(&suite.state, httpClient)
// setup module being tested // setup module being tested
federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(suite.db, fedWorker), tc, suite.tc, testrig.NewTestMediaManager(suite.db, suite.storage)) federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(&suite.state), tc, suite.tc, testrig.NewTestMediaManager(&suite.state))
// setup request // setup request
ctx := context.Background() ctx := context.Background()
@ -142,13 +134,11 @@ func (suite *FederatingProtocolTestSuite) TestAuthenticatePostInbox() {
sendingAccount := suite.testAccounts["remote_account_1"] sendingAccount := suite.testAccounts["remote_account_1"]
inboxAccount := suite.testAccounts["local_account_1"] inboxAccount := suite.testAccounts["local_account_1"]
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
httpClient := testrig.NewMockHTTPClient(nil, "../../testrig/media") httpClient := testrig.NewMockHTTPClient(nil, "../../testrig/media")
tc := testrig.NewTestTransportController(httpClient, suite.db, fedWorker) tc := testrig.NewTestTransportController(&suite.state, httpClient)
// now setup module being tested, with the mock transport controller // now setup module being tested, with the mock transport controller
federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(suite.db, fedWorker), tc, suite.tc, testrig.NewTestMediaManager(suite.db, suite.storage)) federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(&suite.state), tc, suite.tc, testrig.NewTestMediaManager(&suite.state))
request := httptest.NewRequest(http.MethodPost, "http://localhost:8080/users/the_mighty_zork/inbox", nil) request := httptest.NewRequest(http.MethodPost, "http://localhost:8080/users/the_mighty_zork/inbox", nil)
// we need these headers for the request to be validated // we need these headers for the request to be validated
@ -187,13 +177,11 @@ func (suite *FederatingProtocolTestSuite) TestAuthenticatePostGone() {
activity := suite.testActivities["delete_https://somewhere.mysterious/users/rest_in_piss#main-key"] activity := suite.testActivities["delete_https://somewhere.mysterious/users/rest_in_piss#main-key"]
inboxAccount := suite.testAccounts["local_account_1"] inboxAccount := suite.testAccounts["local_account_1"]
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
httpClient := testrig.NewMockHTTPClient(nil, "../../testrig/media") httpClient := testrig.NewMockHTTPClient(nil, "../../testrig/media")
tc := testrig.NewTestTransportController(httpClient, suite.db, fedWorker) tc := testrig.NewTestTransportController(&suite.state, httpClient)
// now setup module being tested, with the mock transport controller // now setup module being tested, with the mock transport controller
federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(suite.db, fedWorker), tc, suite.tc, testrig.NewTestMediaManager(suite.db, suite.storage)) federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(&suite.state), tc, suite.tc, testrig.NewTestMediaManager(&suite.state))
request := httptest.NewRequest(http.MethodPost, "http://localhost:8080/users/the_mighty_zork/inbox", nil) request := httptest.NewRequest(http.MethodPost, "http://localhost:8080/users/the_mighty_zork/inbox", nil)
// we need these headers for the request to be validated // we need these headers for the request to be validated
@ -231,13 +219,11 @@ func (suite *FederatingProtocolTestSuite) TestAuthenticatePostGoneNoTombstoneYet
activity := suite.testActivities["delete_https://somewhere.mysterious/users/rest_in_piss#main-key"] activity := suite.testActivities["delete_https://somewhere.mysterious/users/rest_in_piss#main-key"]
inboxAccount := suite.testAccounts["local_account_1"] inboxAccount := suite.testAccounts["local_account_1"]
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
httpClient := testrig.NewMockHTTPClient(nil, "../../testrig/media") httpClient := testrig.NewMockHTTPClient(nil, "../../testrig/media")
tc := testrig.NewTestTransportController(httpClient, suite.db, fedWorker) tc := testrig.NewTestTransportController(&suite.state, httpClient)
// now setup module being tested, with the mock transport controller // now setup module being tested, with the mock transport controller
federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(suite.db, fedWorker), tc, suite.tc, testrig.NewTestMediaManager(suite.db, suite.storage)) federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(&suite.state), tc, suite.tc, testrig.NewTestMediaManager(&suite.state))
request := httptest.NewRequest(http.MethodPost, "http://localhost:8080/users/the_mighty_zork/inbox", nil) request := httptest.NewRequest(http.MethodPost, "http://localhost:8080/users/the_mighty_zork/inbox", nil)
// we need these headers for the request to be validated // we need these headers for the request to be validated
@ -271,10 +257,9 @@ func (suite *FederatingProtocolTestSuite) TestAuthenticatePostGoneNoTombstoneYet
} }
func (suite *FederatingProtocolTestSuite) TestBlocked1() { func (suite *FederatingProtocolTestSuite) TestBlocked1() {
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
httpClient := testrig.NewMockHTTPClient(nil, "../../testrig/media") httpClient := testrig.NewMockHTTPClient(nil, "../../testrig/media")
tc := testrig.NewTestTransportController(httpClient, suite.db, fedWorker) tc := testrig.NewTestTransportController(&suite.state, httpClient)
federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(suite.db, fedWorker), tc, suite.tc, testrig.NewTestMediaManager(suite.db, suite.storage)) federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(&suite.state), tc, suite.tc, testrig.NewTestMediaManager(&suite.state))
sendingAccount := suite.testAccounts["remote_account_1"] sendingAccount := suite.testAccounts["remote_account_1"]
inboxAccount := suite.testAccounts["local_account_1"] inboxAccount := suite.testAccounts["local_account_1"]
@ -294,10 +279,9 @@ func (suite *FederatingProtocolTestSuite) TestBlocked1() {
} }
func (suite *FederatingProtocolTestSuite) TestBlocked2() { func (suite *FederatingProtocolTestSuite) TestBlocked2() {
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
httpClient := testrig.NewMockHTTPClient(nil, "../../testrig/media") httpClient := testrig.NewMockHTTPClient(nil, "../../testrig/media")
tc := testrig.NewTestTransportController(httpClient, suite.db, fedWorker) tc := testrig.NewTestTransportController(&suite.state, httpClient)
federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(suite.db, fedWorker), tc, suite.tc, testrig.NewTestMediaManager(suite.db, suite.storage)) federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(&suite.state), tc, suite.tc, testrig.NewTestMediaManager(&suite.state))
sendingAccount := suite.testAccounts["remote_account_1"] sendingAccount := suite.testAccounts["remote_account_1"]
inboxAccount := suite.testAccounts["local_account_1"] inboxAccount := suite.testAccounts["local_account_1"]
@ -328,10 +312,9 @@ func (suite *FederatingProtocolTestSuite) TestBlocked2() {
} }
func (suite *FederatingProtocolTestSuite) TestBlocked3() { func (suite *FederatingProtocolTestSuite) TestBlocked3() {
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
httpClient := testrig.NewMockHTTPClient(nil, "../../testrig/media") httpClient := testrig.NewMockHTTPClient(nil, "../../testrig/media")
tc := testrig.NewTestTransportController(httpClient, suite.db, fedWorker) tc := testrig.NewTestTransportController(&suite.state, httpClient)
federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(suite.db, fedWorker), tc, suite.tc, testrig.NewTestMediaManager(suite.db, suite.storage)) federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(&suite.state), tc, suite.tc, testrig.NewTestMediaManager(&suite.state))
sendingAccount := suite.testAccounts["remote_account_1"] sendingAccount := suite.testAccounts["remote_account_1"]
inboxAccount := suite.testAccounts["local_account_1"] inboxAccount := suite.testAccounts["local_account_1"]
@ -365,10 +348,9 @@ func (suite *FederatingProtocolTestSuite) TestBlocked3() {
} }
func (suite *FederatingProtocolTestSuite) TestBlocked4() { func (suite *FederatingProtocolTestSuite) TestBlocked4() {
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
httpClient := testrig.NewMockHTTPClient(nil, "../../testrig/media") httpClient := testrig.NewMockHTTPClient(nil, "../../testrig/media")
tc := testrig.NewTestTransportController(httpClient, suite.db, fedWorker) tc := testrig.NewTestTransportController(&suite.state, httpClient)
federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(suite.db, fedWorker), tc, suite.tc, testrig.NewTestMediaManager(suite.db, suite.storage)) federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(&suite.state), tc, suite.tc, testrig.NewTestMediaManager(&suite.state))
sendingAccount := suite.testAccounts["remote_account_1"] sendingAccount := suite.testAccounts["remote_account_1"]
inboxAccount := suite.testAccounts["local_account_1"] inboxAccount := suite.testAccounts["local_account_1"]

View file

@ -23,6 +23,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage" "github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/typeutils" "github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/testrig" "github.com/superseriousbusiness/gotosocial/testrig"
@ -32,6 +33,7 @@ type FederatorStandardTestSuite struct {
suite.Suite suite.Suite
db db.DB db db.DB
storage *storage.Driver storage *storage.Driver
state state.State
tc typeutils.TypeConverter tc typeutils.TypeConverter
testAccounts map[string]*gtsmodel.Account testAccounts map[string]*gtsmodel.Account
testStatuses map[string]*gtsmodel.Status testStatuses map[string]*gtsmodel.Status
@ -42,8 +44,9 @@ type FederatorStandardTestSuite struct {
// SetupSuite sets some variables on the suite that we can use as consts (more or less) throughout // SetupSuite sets some variables on the suite that we can use as consts (more or less) throughout
func (suite *FederatorStandardTestSuite) SetupSuite() { func (suite *FederatorStandardTestSuite) SetupSuite() {
// setup standard items // setup standard items
testrig.StartWorkers(&suite.state)
suite.storage = testrig.NewInMemoryStorage() suite.storage = testrig.NewInMemoryStorage()
suite.tc = testrig.NewTestTypeConverter(suite.db) suite.state.Storage = suite.storage
suite.testAccounts = testrig.NewTestAccounts() suite.testAccounts = testrig.NewTestAccounts()
suite.testStatuses = testrig.NewTestStatuses() suite.testStatuses = testrig.NewTestStatuses()
suite.testTombstones = testrig.NewTestTombstones() suite.testTombstones = testrig.NewTestTombstones()
@ -52,7 +55,10 @@ func (suite *FederatorStandardTestSuite) SetupSuite() {
func (suite *FederatorStandardTestSuite) SetupTest() { func (suite *FederatorStandardTestSuite) SetupTest() {
testrig.InitTestConfig() testrig.InitTestConfig()
testrig.InitTestLog() testrig.InitTestLog()
suite.db = testrig.NewTestDB() suite.state.Caches.Init()
suite.db = testrig.NewTestDB(&suite.state)
suite.tc = testrig.NewTestTypeConverter(suite.db)
suite.state.DB = suite.db
suite.testActivities = testrig.NewTestActivities(suite.testAccounts) suite.testActivities = testrig.NewTestActivities(suite.testAccounts)
testrig.StandardDBSetup(suite.db, suite.testAccounts) testrig.StandardDBSetup(suite.db, suite.testAccounts)
} }

View file

@ -20,11 +20,10 @@ package media_test
import ( import (
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
gtsmodel "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" gtsmodel "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/media" "github.com/superseriousbusiness/gotosocial/internal/media"
"github.com/superseriousbusiness/gotosocial/internal/messages" "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage" "github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/transport" "github.com/superseriousbusiness/gotosocial/internal/transport"
"github.com/superseriousbusiness/gotosocial/testrig" "github.com/superseriousbusiness/gotosocial/testrig"
@ -35,6 +34,7 @@ type MediaStandardTestSuite struct {
db db.DB db db.DB
storage *storage.Driver storage *storage.Driver
state state.State
manager media.Manager manager media.Manager
transportController transport.Controller transportController transport.Controller
testAttachments map[string]*gtsmodel.MediaAttachment testAttachments map[string]*gtsmodel.MediaAttachment
@ -46,21 +46,27 @@ func (suite *MediaStandardTestSuite) SetupSuite() {
testrig.InitTestConfig() testrig.InitTestConfig()
testrig.InitTestLog() testrig.InitTestLog()
suite.db = testrig.NewTestDB() suite.db = testrig.NewTestDB(&suite.state)
suite.storage = testrig.NewInMemoryStorage() suite.storage = testrig.NewInMemoryStorage()
suite.state.DB = suite.db
suite.state.Storage = suite.storage
} }
func (suite *MediaStandardTestSuite) SetupTest() { func (suite *MediaStandardTestSuite) SetupTest() {
suite.state.Caches.Init()
testrig.StartWorkers(&suite.state)
testrig.StandardStorageSetup(suite.storage, "../../testrig/media") testrig.StandardStorageSetup(suite.storage, "../../testrig/media")
testrig.StandardDBSetup(suite.db, nil) testrig.StandardDBSetup(suite.db, nil)
suite.testAttachments = testrig.NewTestAttachments() suite.testAttachments = testrig.NewTestAttachments()
suite.testAccounts = testrig.NewTestAccounts() suite.testAccounts = testrig.NewTestAccounts()
suite.testEmojis = testrig.NewTestEmojis() suite.testEmojis = testrig.NewTestEmojis()
suite.manager = testrig.NewTestMediaManager(suite.db, suite.storage) suite.manager = testrig.NewTestMediaManager(&suite.state)
suite.transportController = testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../testrig/media"), suite.db, concurrency.NewWorkerPool[messages.FromFederator](0, 0)) suite.transportController = testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../testrig/media"))
} }
func (suite *MediaStandardTestSuite) TearDownTest() { func (suite *MediaStandardTestSuite) TearDownTest() {
testrig.StandardDBTeardown(suite.db) testrig.StandardDBTeardown(suite.db)
testrig.StandardStorageTeardown(suite.storage) testrig.StandardStorageTeardown(suite.storage)
testrig.StopWorkers(&suite.state)
} }

View file

@ -25,6 +25,7 @@ import (
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/testrig" "github.com/superseriousbusiness/gotosocial/testrig"
"github.com/superseriousbusiness/oauth2/v4/models" "github.com/superseriousbusiness/oauth2/v4/models"
) )
@ -32,6 +33,7 @@ import (
type PgClientStoreTestSuite struct { type PgClientStoreTestSuite struct {
suite.Suite suite.Suite
db db.DB db db.DB
state state.State
testClientID string testClientID string
testClientSecret string testClientSecret string
testClientDomain string testClientDomain string
@ -48,9 +50,11 @@ func (suite *PgClientStoreTestSuite) SetupSuite() {
// SetupTest creates a postgres connection and creates the oauth_clients table before each test // SetupTest creates a postgres connection and creates the oauth_clients table before each test
func (suite *PgClientStoreTestSuite) SetupTest() { func (suite *PgClientStoreTestSuite) SetupTest() {
suite.state.Caches.Init()
testrig.InitTestLog() testrig.InitTestLog()
testrig.InitTestConfig() testrig.InitTestConfig()
suite.db = testrig.NewTestDB() suite.db = testrig.NewTestDB(&suite.state)
suite.state.DB = suite.db
testrig.StandardDBSetup(suite.db, nil) testrig.StandardDBSetup(suite.db, nil)
} }

View file

@ -19,13 +19,11 @@
package account package account
import ( import (
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/federation" "github.com/superseriousbusiness/gotosocial/internal/federation"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/media" "github.com/superseriousbusiness/gotosocial/internal/media"
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/text" "github.com/superseriousbusiness/gotosocial/internal/text"
"github.com/superseriousbusiness/gotosocial/internal/typeutils" "github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/visibility" "github.com/superseriousbusiness/gotosocial/internal/visibility"
@ -35,35 +33,32 @@ import (
// //
// It also contains logic for actions towards accounts such as following, blocking, seeing follows, etc. // It also contains logic for actions towards accounts such as following, blocking, seeing follows, etc.
type Processor struct { type Processor struct {
state *state.State
tc typeutils.TypeConverter tc typeutils.TypeConverter
mediaManager media.Manager mediaManager media.Manager
clientWorker *concurrency.WorkerPool[messages.FromClientAPI]
oauthServer oauth.Server oauthServer oauth.Server
filter visibility.Filter filter visibility.Filter
formatter text.Formatter formatter text.Formatter
db db.DB
federator federation.Federator federator federation.Federator
parseMention gtsmodel.ParseMentionFunc parseMention gtsmodel.ParseMentionFunc
} }
// New returns a new account processor. // New returns a new account processor.
func New( func New(
db db.DB, state *state.State,
tc typeutils.TypeConverter, tc typeutils.TypeConverter,
mediaManager media.Manager, mediaManager media.Manager,
oauthServer oauth.Server, oauthServer oauth.Server,
clientWorker *concurrency.WorkerPool[messages.FromClientAPI],
federator federation.Federator, federator federation.Federator,
parseMention gtsmodel.ParseMentionFunc, parseMention gtsmodel.ParseMentionFunc,
) Processor { ) Processor {
return Processor{ return Processor{
state: state,
tc: tc, tc: tc,
mediaManager: mediaManager, mediaManager: mediaManager,
clientWorker: clientWorker,
oauthServer: oauthServer, oauthServer: oauthServer,
filter: visibility.NewFilter(db), filter: visibility.NewFilter(state.DB),
formatter: text.NewFormatter(db), formatter: text.NewFormatter(state.DB),
db: db,
federator: federator, federator: federator,
parseMention: parseMention, parseMention: parseMention,
} }

View file

@ -22,7 +22,6 @@ import (
"context" "context"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email" "github.com/superseriousbusiness/gotosocial/internal/email"
"github.com/superseriousbusiness/gotosocial/internal/federation" "github.com/superseriousbusiness/gotosocial/internal/federation"
@ -32,6 +31,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/processing" "github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/processing/account" "github.com/superseriousbusiness/gotosocial/internal/processing/account"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage" "github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/transport" "github.com/superseriousbusiness/gotosocial/internal/transport"
"github.com/superseriousbusiness/gotosocial/internal/typeutils" "github.com/superseriousbusiness/gotosocial/internal/typeutils"
@ -44,6 +44,7 @@ type AccountStandardTestSuite struct {
db db.DB db db.DB
tc typeutils.TypeConverter tc typeutils.TypeConverter
storage *storage.Driver storage *storage.Driver
state state.State
mediaManager media.Manager mediaManager media.Manager
oauthServer oauth.Server oauthServer oauth.Server
fromClientAPIChan chan messages.FromClientAPI fromClientAPIChan chan messages.FromClientAPI
@ -76,30 +77,30 @@ func (suite *AccountStandardTestSuite) SetupSuite() {
} }
func (suite *AccountStandardTestSuite) SetupTest() { func (suite *AccountStandardTestSuite) SetupTest() {
suite.state.Caches.Init()
testrig.StartWorkers(&suite.state)
testrig.InitTestLog() testrig.InitTestLog()
testrig.InitTestConfig() testrig.InitTestConfig()
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) suite.db = testrig.NewTestDB(&suite.state)
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) suite.state.DB = suite.db
clientWorker.SetProcessor(func(_ context.Context, msg messages.FromClientAPI) error {
suite.fromClientAPIChan <- msg
return nil
})
_ = fedWorker.Start()
_ = clientWorker.Start()
suite.db = testrig.NewTestDB()
suite.tc = testrig.NewTestTypeConverter(suite.db) suite.tc = testrig.NewTestTypeConverter(suite.db)
suite.storage = testrig.NewInMemoryStorage() suite.storage = testrig.NewInMemoryStorage()
suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage) suite.state.Storage = suite.storage
suite.mediaManager = testrig.NewTestMediaManager(&suite.state)
suite.oauthServer = testrig.NewTestOauthServer(suite.db) suite.oauthServer = testrig.NewTestOauthServer(suite.db)
suite.fromClientAPIChan = make(chan messages.FromClientAPI, 100) suite.fromClientAPIChan = make(chan messages.FromClientAPI, 100)
suite.transportController = testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../testrig/media"), suite.db, fedWorker) suite.state.Workers.EnqueueClientAPI = func(ctx context.Context, msg messages.FromClientAPI) {
suite.federator = testrig.NewTestFederator(suite.db, suite.transportController, suite.storage, suite.mediaManager, fedWorker) suite.fromClientAPIChan <- msg
}
suite.transportController = testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../testrig/media"))
suite.federator = testrig.NewTestFederator(&suite.state, suite.transportController, suite.mediaManager)
suite.sentEmails = make(map[string]string) suite.sentEmails = make(map[string]string)
suite.emailSender = testrig.NewEmailSender("../../../web/template/", suite.sentEmails) suite.emailSender = testrig.NewEmailSender("../../../web/template/", suite.sentEmails)
suite.accountProcessor = account.New(suite.db, suite.tc, suite.mediaManager, suite.oauthServer, clientWorker, suite.federator, processing.GetParseMentionFunc(suite.db, suite.federator)) suite.accountProcessor = account.New(&suite.state, suite.tc, suite.mediaManager, suite.oauthServer, suite.federator, processing.GetParseMentionFunc(suite.db, suite.federator))
testrig.StandardDBSetup(suite.db, nil) testrig.StandardDBSetup(suite.db, nil)
testrig.StandardStorageSetup(suite.storage, "../../../testrig/media") testrig.StandardStorageSetup(suite.storage, "../../../testrig/media")
} }
@ -107,4 +108,5 @@ func (suite *AccountStandardTestSuite) SetupTest() {
func (suite *AccountStandardTestSuite) TearDownTest() { func (suite *AccountStandardTestSuite) TearDownTest() {
testrig.StandardDBTeardown(suite.db) testrig.StandardDBTeardown(suite.db)
testrig.StandardStorageTeardown(suite.storage) testrig.StandardStorageTeardown(suite.storage)
testrig.StopWorkers(&suite.state)
} }

View file

@ -36,13 +36,13 @@ import (
// BlockCreate handles the creation of a block from requestingAccount to targetAccountID, either remote or local. // BlockCreate handles the creation of a block from requestingAccount to targetAccountID, either remote or local.
func (p *Processor) BlockCreate(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) { func (p *Processor) BlockCreate(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) {
// make sure the target account actually exists in our db // make sure the target account actually exists in our db
targetAccount, err := p.db.GetAccountByID(ctx, targetAccountID) targetAccount, err := p.state.DB.GetAccountByID(ctx, targetAccountID)
if err != nil { if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("BlockCreate: error getting account %s from the db: %s", targetAccountID, err)) return nil, gtserror.NewErrorNotFound(fmt.Errorf("BlockCreate: error getting account %s from the db: %s", targetAccountID, err))
} }
// if requestingAccount already blocks target account, we don't need to do anything // if requestingAccount already blocks target account, we don't need to do anything
if blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, targetAccountID, false); err != nil { if blocked, err := p.state.DB.IsBlocked(ctx, requestingAccount.ID, targetAccountID, false); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("BlockCreate: error checking existence of block: %s", err)) return nil, gtserror.NewErrorInternalError(fmt.Errorf("BlockCreate: error checking existence of block: %s", err))
} else if blocked { } else if blocked {
return p.RelationshipGet(ctx, requestingAccount, targetAccountID) return p.RelationshipGet(ctx, requestingAccount, targetAccountID)
@ -64,18 +64,18 @@ func (p *Processor) BlockCreate(ctx context.Context, requestingAccount *gtsmodel
block.URI = uris.GenerateURIForBlock(requestingAccount.Username, newBlockID) block.URI = uris.GenerateURIForBlock(requestingAccount.Username, newBlockID)
// whack it in the database // whack it in the database
if err := p.db.PutBlock(ctx, block); err != nil { if err := p.state.DB.PutBlock(ctx, block); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("BlockCreate: error creating block in db: %s", err)) return nil, gtserror.NewErrorInternalError(fmt.Errorf("BlockCreate: error creating block in db: %s", err))
} }
// clear any follows or follow requests from the blocked account to the target account -- this is a simple delete // clear any follows or follow requests from the blocked account to the target account -- this is a simple delete
if err := p.db.DeleteWhere(ctx, []db.Where{ if err := p.state.DB.DeleteWhere(ctx, []db.Where{
{Key: "account_id", Value: targetAccountID}, {Key: "account_id", Value: targetAccountID},
{Key: "target_account_id", Value: requestingAccount.ID}, {Key: "target_account_id", Value: requestingAccount.ID},
}, &gtsmodel.Follow{}); err != nil { }, &gtsmodel.Follow{}); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("BlockCreate: error removing follow in db: %s", err)) return nil, gtserror.NewErrorInternalError(fmt.Errorf("BlockCreate: error removing follow in db: %s", err))
} }
if err := p.db.DeleteWhere(ctx, []db.Where{ if err := p.state.DB.DeleteWhere(ctx, []db.Where{
{Key: "account_id", Value: targetAccountID}, {Key: "account_id", Value: targetAccountID},
{Key: "target_account_id", Value: requestingAccount.ID}, {Key: "target_account_id", Value: requestingAccount.ID},
}, &gtsmodel.FollowRequest{}); err != nil { }, &gtsmodel.FollowRequest{}); err != nil {
@ -89,12 +89,12 @@ func (p *Processor) BlockCreate(ctx context.Context, requestingAccount *gtsmodel
var frChanged bool var frChanged bool
var frURI string var frURI string
fr := &gtsmodel.FollowRequest{} fr := &gtsmodel.FollowRequest{}
if err := p.db.GetWhere(ctx, []db.Where{ if err := p.state.DB.GetWhere(ctx, []db.Where{
{Key: "account_id", Value: requestingAccount.ID}, {Key: "account_id", Value: requestingAccount.ID},
{Key: "target_account_id", Value: targetAccountID}, {Key: "target_account_id", Value: targetAccountID},
}, fr); err == nil { }, fr); err == nil {
frURI = fr.URI frURI = fr.URI
if err := p.db.DeleteByID(ctx, fr.ID, fr); err != nil { if err := p.state.DB.DeleteByID(ctx, fr.ID, fr); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("BlockCreate: error removing follow request from db: %s", err)) return nil, gtserror.NewErrorInternalError(fmt.Errorf("BlockCreate: error removing follow request from db: %s", err))
} }
frChanged = true frChanged = true
@ -104,12 +104,12 @@ func (p *Processor) BlockCreate(ctx context.Context, requestingAccount *gtsmodel
var fChanged bool var fChanged bool
var fURI string var fURI string
f := &gtsmodel.Follow{} f := &gtsmodel.Follow{}
if err := p.db.GetWhere(ctx, []db.Where{ if err := p.state.DB.GetWhere(ctx, []db.Where{
{Key: "account_id", Value: requestingAccount.ID}, {Key: "account_id", Value: requestingAccount.ID},
{Key: "target_account_id", Value: targetAccountID}, {Key: "target_account_id", Value: targetAccountID},
}, f); err == nil { }, f); err == nil {
fURI = f.URI fURI = f.URI
if err := p.db.DeleteByID(ctx, f.ID, f); err != nil { if err := p.state.DB.DeleteByID(ctx, f.ID, f); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("BlockCreate: error removing follow from db: %s", err)) return nil, gtserror.NewErrorInternalError(fmt.Errorf("BlockCreate: error removing follow from db: %s", err))
} }
fChanged = true fChanged = true
@ -117,7 +117,7 @@ func (p *Processor) BlockCreate(ctx context.Context, requestingAccount *gtsmodel
// follow request status changed so send the UNDO activity to the channel for async processing // follow request status changed so send the UNDO activity to the channel for async processing
if frChanged { if frChanged {
p.clientWorker.Queue(messages.FromClientAPI{ p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{
APObjectType: ap.ActivityFollow, APObjectType: ap.ActivityFollow,
APActivityType: ap.ActivityUndo, APActivityType: ap.ActivityUndo,
GTSModel: &gtsmodel.Follow{ GTSModel: &gtsmodel.Follow{
@ -132,7 +132,7 @@ func (p *Processor) BlockCreate(ctx context.Context, requestingAccount *gtsmodel
// follow status changed so send the UNDO activity to the channel for async processing // follow status changed so send the UNDO activity to the channel for async processing
if fChanged { if fChanged {
p.clientWorker.Queue(messages.FromClientAPI{ p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{
APObjectType: ap.ActivityFollow, APObjectType: ap.ActivityFollow,
APActivityType: ap.ActivityUndo, APActivityType: ap.ActivityUndo,
GTSModel: &gtsmodel.Follow{ GTSModel: &gtsmodel.Follow{
@ -146,7 +146,7 @@ func (p *Processor) BlockCreate(ctx context.Context, requestingAccount *gtsmodel
} }
// handle the rest of the block process asynchronously // handle the rest of the block process asynchronously
p.clientWorker.Queue(messages.FromClientAPI{ p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{
APObjectType: ap.ActivityBlock, APObjectType: ap.ActivityBlock,
APActivityType: ap.ActivityCreate, APActivityType: ap.ActivityCreate,
GTSModel: block, GTSModel: block,
@ -160,23 +160,23 @@ func (p *Processor) BlockCreate(ctx context.Context, requestingAccount *gtsmodel
// BlockRemove handles the removal of a block from requestingAccount to targetAccountID, either remote or local. // BlockRemove handles the removal of a block from requestingAccount to targetAccountID, either remote or local.
func (p *Processor) BlockRemove(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) { func (p *Processor) BlockRemove(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) {
// make sure the target account actually exists in our db // make sure the target account actually exists in our db
targetAccount, err := p.db.GetAccountByID(ctx, targetAccountID) targetAccount, err := p.state.DB.GetAccountByID(ctx, targetAccountID)
if err != nil { if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("BlockCreate: error getting account %s from the db: %s", targetAccountID, err)) return nil, gtserror.NewErrorNotFound(fmt.Errorf("BlockCreate: error getting account %s from the db: %s", targetAccountID, err))
} }
// check if a block exists, and remove it if it does // check if a block exists, and remove it if it does
block, err := p.db.GetBlock(ctx, requestingAccount.ID, targetAccountID) block, err := p.state.DB.GetBlock(ctx, requestingAccount.ID, targetAccountID)
if err == nil { if err == nil {
// we got a block, remove it // we got a block, remove it
block.Account = requestingAccount block.Account = requestingAccount
block.TargetAccount = targetAccount block.TargetAccount = targetAccount
if err := p.db.DeleteBlockByID(ctx, block.ID); err != nil { if err := p.state.DB.DeleteBlockByID(ctx, block.ID); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("BlockRemove: error removing block from db: %s", err)) return nil, gtserror.NewErrorInternalError(fmt.Errorf("BlockRemove: error removing block from db: %s", err))
} }
// send the UNDO activity to the client worker for async processing // send the UNDO activity to the client worker for async processing
p.clientWorker.Queue(messages.FromClientAPI{ p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{
APObjectType: ap.ActivityBlock, APObjectType: ap.ActivityBlock,
APActivityType: ap.ActivityUndo, APActivityType: ap.ActivityUndo,
GTSModel: block, GTSModel: block,

View file

@ -34,7 +34,7 @@ import (
// BookmarksGet returns a pageable response of statuses that are bookmarked by requestingAccount. // BookmarksGet returns a pageable response of statuses that are bookmarked by requestingAccount.
// Paging for this response is done based on bookmark ID rather than status ID. // Paging for this response is done based on bookmark ID rather than status ID.
func (p *Processor) BookmarksGet(ctx context.Context, requestingAccount *gtsmodel.Account, limit int, maxID string, minID string) (*apimodel.PageableResponse, gtserror.WithCode) { func (p *Processor) BookmarksGet(ctx context.Context, requestingAccount *gtsmodel.Account, limit int, maxID string, minID string) (*apimodel.PageableResponse, gtserror.WithCode) {
bookmarks, err := p.db.GetBookmarks(ctx, requestingAccount.ID, limit, maxID, minID) bookmarks, err := p.state.DB.GetBookmarks(ctx, requestingAccount.ID, limit, maxID, minID)
if err != nil { if err != nil {
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
} }
@ -47,7 +47,7 @@ func (p *Processor) BookmarksGet(ctx context.Context, requestingAccount *gtsmode
) )
for _, bookmark := range bookmarks { for _, bookmark := range bookmarks {
status, err := p.db.GetStatusByID(ctx, bookmark.StatusID) status, err := p.state.DB.GetStatusByID(ctx, bookmark.StatusID)
if err != nil { if err != nil {
if errors.Is(err, db.ErrNoEntries) { if errors.Is(err, db.ErrNoEntries) {
// We just don't have the status for some reason. // We just don't have the status for some reason.

View file

@ -35,7 +35,7 @@ import (
// Create processes the given form for creating a new account, returning an oauth token for that account if successful. // Create processes the given form for creating a new account, returning an oauth token for that account if successful.
func (p *Processor) Create(ctx context.Context, applicationToken oauth2.TokenInfo, application *gtsmodel.Application, form *apimodel.AccountCreateRequest) (*apimodel.Token, gtserror.WithCode) { func (p *Processor) Create(ctx context.Context, applicationToken oauth2.TokenInfo, application *gtsmodel.Application, form *apimodel.AccountCreateRequest) (*apimodel.Token, gtserror.WithCode) {
emailAvailable, err := p.db.IsEmailAvailable(ctx, form.Email) emailAvailable, err := p.state.DB.IsEmailAvailable(ctx, form.Email)
if err != nil { if err != nil {
return nil, gtserror.NewErrorBadRequest(err) return nil, gtserror.NewErrorBadRequest(err)
} }
@ -43,7 +43,7 @@ func (p *Processor) Create(ctx context.Context, applicationToken oauth2.TokenInf
return nil, gtserror.NewErrorConflict(fmt.Errorf("email address %s is not available", form.Email)) return nil, gtserror.NewErrorConflict(fmt.Errorf("email address %s is not available", form.Email))
} }
usernameAvailable, err := p.db.IsUsernameAvailable(ctx, form.Username) usernameAvailable, err := p.state.DB.IsUsernameAvailable(ctx, form.Username)
if err != nil { if err != nil {
return nil, gtserror.NewErrorBadRequest(err) return nil, gtserror.NewErrorBadRequest(err)
} }
@ -61,7 +61,7 @@ func (p *Processor) Create(ctx context.Context, applicationToken oauth2.TokenInf
} }
log.Trace(ctx, "creating new username and account") log.Trace(ctx, "creating new username and account")
user, err := p.db.NewSignup(ctx, form.Username, text.SanitizePlaintext(reason), approvalRequired, form.Email, form.Password, form.IP, form.Locale, application.ID, false, "", false) user, err := p.state.DB.NewSignup(ctx, form.Username, text.SanitizePlaintext(reason), approvalRequired, form.Email, form.Password, form.IP, form.Locale, application.ID, false, "", false)
if err != nil { if err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("error creating new signup in the database: %s", err)) return nil, gtserror.NewErrorInternalError(fmt.Errorf("error creating new signup in the database: %s", err))
} }
@ -73,7 +73,7 @@ func (p *Processor) Create(ctx context.Context, applicationToken oauth2.TokenInf
} }
if user.Account == nil { if user.Account == nil {
a, err := p.db.GetAccountByID(ctx, user.AccountID) a, err := p.state.DB.GetAccountByID(ctx, user.AccountID)
if err != nil { if err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("error getting new account from the database: %s", err)) return nil, gtserror.NewErrorInternalError(fmt.Errorf("error getting new account from the database: %s", err))
} }
@ -82,7 +82,7 @@ func (p *Processor) Create(ctx context.Context, applicationToken oauth2.TokenInf
// there are side effects for creating a new account (sending confirmation emails etc) // there are side effects for creating a new account (sending confirmation emails etc)
// so pass a message to the processor so that it can do it asynchronously // so pass a message to the processor so that it can do it asynchronously
p.clientWorker.Queue(messages.FromClientAPI{ p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{
APObjectType: ap.ObjectProfile, APObjectType: ap.ObjectProfile,
APActivityType: ap.ActivityCreate, APActivityType: ap.ActivityCreate,
GTSModel: user.Account, GTSModel: user.Account,

View file

@ -54,22 +54,22 @@ func (p *Processor) Delete(ctx context.Context, account *gtsmodel.Account, origi
if account.Domain == "" { if account.Domain == "" {
// see if we can get a user for this account // see if we can get a user for this account
var err error var err error
if user, err = p.db.GetUserByAccountID(ctx, account.ID); err == nil { if user, err = p.state.DB.GetUserByAccountID(ctx, account.ID); err == nil {
// we got one! select all tokens with the user's ID // we got one! select all tokens with the user's ID
tokens := []*gtsmodel.Token{} tokens := []*gtsmodel.Token{}
if err := p.db.GetWhere(ctx, []db.Where{{Key: "user_id", Value: user.ID}}, &tokens); err == nil { if err := p.state.DB.GetWhere(ctx, []db.Where{{Key: "user_id", Value: user.ID}}, &tokens); err == nil {
// we have some tokens to delete // we have some tokens to delete
for _, t := range tokens { for _, t := range tokens {
// delete client(s) associated with this token // delete client(s) associated with this token
if err := p.db.DeleteByID(ctx, t.ClientID, &gtsmodel.Client{}); err != nil { if err := p.state.DB.DeleteByID(ctx, t.ClientID, &gtsmodel.Client{}); err != nil {
l.Errorf("error deleting oauth client: %s", err) l.Errorf("error deleting oauth client: %s", err)
} }
// delete application(s) associated with this token // delete application(s) associated with this token
if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "client_id", Value: t.ClientID}}, &gtsmodel.Application{}); err != nil { if err := p.state.DB.DeleteWhere(ctx, []db.Where{{Key: "client_id", Value: t.ClientID}}, &gtsmodel.Application{}); err != nil {
l.Errorf("error deleting application: %s", err) l.Errorf("error deleting application: %s", err)
} }
// delete the token itself // delete the token itself
if err := p.db.DeleteByID(ctx, t.ID, t); err != nil { if err := p.state.DB.DeleteByID(ctx, t.ID, t); err != nil {
l.Errorf("error deleting oauth token: %s", err) l.Errorf("error deleting oauth token: %s", err)
} }
} }
@ -80,12 +80,12 @@ func (p *Processor) Delete(ctx context.Context, account *gtsmodel.Account, origi
// 2. Delete account's blocks // 2. Delete account's blocks
l.Trace("deleting account blocks") l.Trace("deleting account blocks")
// first delete any blocks that this account created // first delete any blocks that this account created
if err := p.db.DeleteBlocksByOriginAccountID(ctx, account.ID); err != nil { if err := p.state.DB.DeleteBlocksByOriginAccountID(ctx, account.ID); err != nil {
l.Errorf("error deleting blocks created by account: %s", err) l.Errorf("error deleting blocks created by account: %s", err)
} }
// now delete any blocks that target this account // now delete any blocks that target this account
if err := p.db.DeleteBlocksByTargetAccountID(ctx, account.ID); err != nil { if err := p.state.DB.DeleteBlocksByTargetAccountID(ctx, account.ID); err != nil {
l.Errorf("error deleting blocks targeting account: %s", err) l.Errorf("error deleting blocks targeting account: %s", err)
} }
@ -96,12 +96,12 @@ func (p *Processor) Delete(ctx context.Context, account *gtsmodel.Account, origi
// TODO: federate these if necessary // TODO: federate these if necessary
l.Trace("deleting account follow requests") l.Trace("deleting account follow requests")
// first delete any follow requests that this account created // first delete any follow requests that this account created
if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.FollowRequest{}); err != nil { if err := p.state.DB.DeleteWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.FollowRequest{}); err != nil {
l.Errorf("error deleting follow requests created by account: %s", err) l.Errorf("error deleting follow requests created by account: %s", err)
} }
// now delete any follow requests that target this account // now delete any follow requests that target this account
if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "target_account_id", Value: account.ID}}, &[]*gtsmodel.FollowRequest{}); err != nil { if err := p.state.DB.DeleteWhere(ctx, []db.Where{{Key: "target_account_id", Value: account.ID}}, &[]*gtsmodel.FollowRequest{}); err != nil {
l.Errorf("error deleting follow requests targeting account: %s", err) l.Errorf("error deleting follow requests targeting account: %s", err)
} }
@ -109,12 +109,12 @@ func (p *Processor) Delete(ctx context.Context, account *gtsmodel.Account, origi
// TODO: federate these if necessary // TODO: federate these if necessary
l.Trace("deleting account follows") l.Trace("deleting account follows")
// first delete any follows that this account created // first delete any follows that this account created
if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.Follow{}); err != nil { if err := p.state.DB.DeleteWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.Follow{}); err != nil {
l.Errorf("error deleting follows created by account: %s", err) l.Errorf("error deleting follows created by account: %s", err)
} }
// now delete any follows that target this account // now delete any follows that target this account
if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "target_account_id", Value: account.ID}}, &[]*gtsmodel.Follow{}); err != nil { if err := p.state.DB.DeleteWhere(ctx, []db.Where{{Key: "target_account_id", Value: account.ID}}, &[]*gtsmodel.Follow{}); err != nil {
l.Errorf("error deleting follows targeting account: %s", err) l.Errorf("error deleting follows targeting account: %s", err)
} }
@ -129,7 +129,7 @@ func (p *Processor) Delete(ctx context.Context, account *gtsmodel.Account, origi
for { for {
// Fetch next block of account statuses from database // Fetch next block of account statuses from database
statuses, err := p.db.GetAccountStatuses(ctx, account.ID, 20, false, false, maxID, "", false, false) statuses, err := p.state.DB.GetAccountStatuses(ctx, account.ID, 20, false, false, maxID, "", false, false)
if err != nil { if err != nil {
if !errors.Is(err, db.ErrNoEntries) { if !errors.Is(err, db.ErrNoEntries) {
// an actual error has occurred // an actual error has occurred
@ -149,7 +149,7 @@ func (p *Processor) Delete(ctx context.Context, account *gtsmodel.Account, origi
l.Tracef("queue client API status delete: %s", status.ID) l.Tracef("queue client API status delete: %s", status.ID)
// pass the status delete through the client api channel for processing // pass the status delete through the client api channel for processing
p.clientWorker.Queue(messages.FromClientAPI{ p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{
APObjectType: ap.ObjectNote, APObjectType: ap.ObjectNote,
APActivityType: ap.ActivityDelete, APActivityType: ap.ActivityDelete,
GTSModel: status, GTSModel: status,
@ -158,7 +158,7 @@ func (p *Processor) Delete(ctx context.Context, account *gtsmodel.Account, origi
}) })
// Look for any boosts of this status in DB // Look for any boosts of this status in DB
boosts, err := p.db.GetStatusReblogs(ctx, status) boosts, err := p.state.DB.GetStatusReblogs(ctx, status)
if err != nil && !errors.Is(err, db.ErrNoEntries) { if err != nil && !errors.Is(err, db.ErrNoEntries) {
l.Errorf("error fetching status reblogs for %q: %v", status.ID, err) l.Errorf("error fetching status reblogs for %q: %v", status.ID, err)
continue continue
@ -167,7 +167,7 @@ func (p *Processor) Delete(ctx context.Context, account *gtsmodel.Account, origi
for _, boost := range boosts { for _, boost := range boosts {
if boost.Account == nil { if boost.Account == nil {
// Fetch the relevant account for this status boost // Fetch the relevant account for this status boost
boostAcc, err := p.db.GetAccountByID(ctx, boost.AccountID) boostAcc, err := p.state.DB.GetAccountByID(ctx, boost.AccountID)
if err != nil { if err != nil {
l.Errorf("error fetching boosted status account for %q: %v", boost.AccountID, err) l.Errorf("error fetching boosted status account for %q: %v", boost.AccountID, err)
continue continue
@ -180,7 +180,7 @@ func (p *Processor) Delete(ctx context.Context, account *gtsmodel.Account, origi
l.Tracef("queue client API boost delete: %s", status.ID) l.Tracef("queue client API boost delete: %s", status.ID)
// pass the boost delete through the client api channel for processing // pass the boost delete through the client api channel for processing
p.clientWorker.Queue(messages.FromClientAPI{ p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{
APObjectType: ap.ActivityAnnounce, APObjectType: ap.ActivityAnnounce,
APActivityType: ap.ActivityUndo, APActivityType: ap.ActivityUndo,
GTSModel: status, GTSModel: status,
@ -197,31 +197,31 @@ func (p *Processor) Delete(ctx context.Context, account *gtsmodel.Account, origi
// 10. Delete account's notifications // 10. Delete account's notifications
l.Trace("deleting account notifications") l.Trace("deleting account notifications")
// first notifications created by account // first notifications created by account
if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "origin_account_id", Value: account.ID}}, &[]*gtsmodel.Notification{}); err != nil { if err := p.state.DB.DeleteWhere(ctx, []db.Where{{Key: "origin_account_id", Value: account.ID}}, &[]*gtsmodel.Notification{}); err != nil {
l.Errorf("error deleting notifications created by account: %s", err) l.Errorf("error deleting notifications created by account: %s", err)
} }
// now notifications targeting account // now notifications targeting account
if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "target_account_id", Value: account.ID}}, &[]*gtsmodel.Notification{}); err != nil { if err := p.state.DB.DeleteWhere(ctx, []db.Where{{Key: "target_account_id", Value: account.ID}}, &[]*gtsmodel.Notification{}); err != nil {
l.Errorf("error deleting notifications targeting account: %s", err) l.Errorf("error deleting notifications targeting account: %s", err)
} }
// 11. Delete account's bookmarks // 11. Delete account's bookmarks
l.Trace("deleting account bookmarks") l.Trace("deleting account bookmarks")
if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.StatusBookmark{}); err != nil { if err := p.state.DB.DeleteWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.StatusBookmark{}); err != nil {
l.Errorf("error deleting bookmarks created by account: %s", err) l.Errorf("error deleting bookmarks created by account: %s", err)
} }
// 12. Delete account's faves // 12. Delete account's faves
// TODO: federate these if necessary // TODO: federate these if necessary
l.Trace("deleting account faves") l.Trace("deleting account faves")
if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.StatusFave{}); err != nil { if err := p.state.DB.DeleteWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.StatusFave{}); err != nil {
l.Errorf("error deleting faves created by account: %s", err) l.Errorf("error deleting faves created by account: %s", err)
} }
// 13. Delete account's mutes // 13. Delete account's mutes
l.Trace("deleting account mutes") l.Trace("deleting account mutes")
if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.StatusMute{}); err != nil { if err := p.state.DB.DeleteWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.StatusMute{}); err != nil {
l.Errorf("error deleting status mutes created by account: %s", err) l.Errorf("error deleting status mutes created by account: %s", err)
} }
@ -234,7 +234,7 @@ func (p *Processor) Delete(ctx context.Context, account *gtsmodel.Account, origi
// 16. Delete account's user // 16. Delete account's user
if user != nil { if user != nil {
l.Trace("deleting account user") l.Trace("deleting account user")
if err := p.db.DeleteUserByID(ctx, user.ID); err != nil { if err := p.state.DB.DeleteUserByID(ctx, user.ID); err != nil {
return gtserror.NewErrorInternalError(err) return gtserror.NewErrorInternalError(err)
} }
} }
@ -261,7 +261,7 @@ func (p *Processor) Delete(ctx context.Context, account *gtsmodel.Account, origi
account.Discoverable = &discoverable account.Discoverable = &discoverable
account.SuspendedAt = time.Now() account.SuspendedAt = time.Now()
account.SuspensionOrigin = origin account.SuspensionOrigin = origin
err := p.db.UpdateAccount(ctx, account) err := p.state.DB.UpdateAccount(ctx, account)
if err != nil { if err != nil {
return gtserror.NewErrorInternalError(err) return gtserror.NewErrorInternalError(err)
} }
@ -281,7 +281,7 @@ func (p *Processor) DeleteLocal(ctx context.Context, account *gtsmodel.Account,
if form.DeleteOriginID == account.ID { if form.DeleteOriginID == account.ID {
// the account owner themself has requested deletion via the API, get their user from the db // the account owner themself has requested deletion via the API, get their user from the db
user, err := p.db.GetUserByAccountID(ctx, account.ID) user, err := p.state.DB.GetUserByAccountID(ctx, account.ID)
if err != nil { if err != nil {
return gtserror.NewErrorInternalError(err) return gtserror.NewErrorInternalError(err)
} }
@ -301,7 +301,7 @@ func (p *Processor) DeleteLocal(ctx context.Context, account *gtsmodel.Account,
} else { } else {
// the delete has been requested by some other account, grab it; // the delete has been requested by some other account, grab it;
// if we've reached this point we know it has permission already // if we've reached this point we know it has permission already
requestingAccount, err := p.db.GetAccountByID(ctx, form.DeleteOriginID) requestingAccount, err := p.state.DB.GetAccountByID(ctx, form.DeleteOriginID)
if err != nil { if err != nil {
return gtserror.NewErrorInternalError(err) return gtserror.NewErrorInternalError(err)
} }
@ -310,7 +310,7 @@ func (p *Processor) DeleteLocal(ctx context.Context, account *gtsmodel.Account,
} }
// put the delete in the processor queue to handle the rest of it asynchronously // put the delete in the processor queue to handle the rest of it asynchronously
p.clientWorker.Queue(fromClientAPIMessage) p.state.Workers.EnqueueClientAPI(ctx, fromClientAPIMessage)
return nil return nil
} }

View file

@ -35,14 +35,14 @@ import (
// FollowCreate handles a follow request to an account, either remote or local. // FollowCreate handles a follow request to an account, either remote or local.
func (p *Processor) FollowCreate(ctx context.Context, requestingAccount *gtsmodel.Account, form *apimodel.AccountFollowRequest) (*apimodel.Relationship, gtserror.WithCode) { func (p *Processor) FollowCreate(ctx context.Context, requestingAccount *gtsmodel.Account, form *apimodel.AccountFollowRequest) (*apimodel.Relationship, gtserror.WithCode) {
// if there's a block between the accounts we shouldn't create the request ofc // if there's a block between the accounts we shouldn't create the request ofc
if blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, form.ID, true); err != nil { if blocked, err := p.state.DB.IsBlocked(ctx, requestingAccount.ID, form.ID, true); err != nil {
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
} else if blocked { } else if blocked {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("block exists between accounts")) return nil, gtserror.NewErrorNotFound(fmt.Errorf("block exists between accounts"))
} }
// make sure the target account actually exists in our db // make sure the target account actually exists in our db
targetAcct, err := p.db.GetAccountByID(ctx, form.ID) targetAcct, err := p.state.DB.GetAccountByID(ctx, form.ID)
if err != nil { if err != nil {
if err == db.ErrNoEntries { if err == db.ErrNoEntries {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("accountfollowcreate: account %s not found in the db: %s", form.ID, err)) return nil, gtserror.NewErrorNotFound(fmt.Errorf("accountfollowcreate: account %s not found in the db: %s", form.ID, err))
@ -51,7 +51,7 @@ func (p *Processor) FollowCreate(ctx context.Context, requestingAccount *gtsmode
} }
// check if a follow exists already // check if a follow exists already
if follows, err := p.db.IsFollowing(ctx, requestingAccount, targetAcct); err != nil { if follows, err := p.state.DB.IsFollowing(ctx, requestingAccount, targetAcct); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("accountfollowcreate: error checking follow in db: %s", err)) return nil, gtserror.NewErrorInternalError(fmt.Errorf("accountfollowcreate: error checking follow in db: %s", err))
} else if follows { } else if follows {
// already follows so just return the relationship // already follows so just return the relationship
@ -59,7 +59,7 @@ func (p *Processor) FollowCreate(ctx context.Context, requestingAccount *gtsmode
} }
// check if a follow request exists already // check if a follow request exists already
if followRequested, err := p.db.IsFollowRequested(ctx, requestingAccount, targetAcct); err != nil { if followRequested, err := p.state.DB.IsFollowRequested(ctx, requestingAccount, targetAcct); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("accountfollowcreate: error checking follow request in db: %s", err)) return nil, gtserror.NewErrorInternalError(fmt.Errorf("accountfollowcreate: error checking follow request in db: %s", err))
} else if followRequested { } else if followRequested {
// already follow requested so just return the relationship // already follow requested so just return the relationship
@ -95,13 +95,13 @@ func (p *Processor) FollowCreate(ctx context.Context, requestingAccount *gtsmode
} }
// whack it in the database // whack it in the database
if err := p.db.Put(ctx, fr); err != nil { if err := p.state.DB.Put(ctx, fr); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("accountfollowcreate: error creating follow request in db: %s", err)) return nil, gtserror.NewErrorInternalError(fmt.Errorf("accountfollowcreate: error creating follow request in db: %s", err))
} }
// if it's a local account that's not locked we can just straight up accept the follow request // if it's a local account that's not locked we can just straight up accept the follow request
if !*targetAcct.Locked && targetAcct.Domain == "" { if !*targetAcct.Locked && targetAcct.Domain == "" {
if _, err := p.db.AcceptFollowRequest(ctx, requestingAccount.ID, form.ID); err != nil { if _, err := p.state.DB.AcceptFollowRequest(ctx, requestingAccount.ID, form.ID); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("accountfollowcreate: error accepting folow request for local unlocked account: %s", err)) return nil, gtserror.NewErrorInternalError(fmt.Errorf("accountfollowcreate: error accepting folow request for local unlocked account: %s", err))
} }
// return the new relationship // return the new relationship
@ -109,7 +109,7 @@ func (p *Processor) FollowCreate(ctx context.Context, requestingAccount *gtsmode
} }
// otherwise we leave the follow request as it is and we handle the rest of the process asynchronously // otherwise we leave the follow request as it is and we handle the rest of the process asynchronously
p.clientWorker.Queue(messages.FromClientAPI{ p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{
APObjectType: ap.ActivityFollow, APObjectType: ap.ActivityFollow,
APActivityType: ap.ActivityCreate, APActivityType: ap.ActivityCreate,
GTSModel: fr, GTSModel: fr,
@ -124,7 +124,7 @@ func (p *Processor) FollowCreate(ctx context.Context, requestingAccount *gtsmode
// FollowRemove handles the removal of a follow/follow request to an account, either remote or local. // FollowRemove handles the removal of a follow/follow request to an account, either remote or local.
func (p *Processor) FollowRemove(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) { func (p *Processor) FollowRemove(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) {
// if there's a block between the accounts we shouldn't do anything // if there's a block between the accounts we shouldn't do anything
blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, targetAccountID, true) blocked, err := p.state.DB.IsBlocked(ctx, requestingAccount.ID, targetAccountID, true)
if err != nil { if err != nil {
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
} }
@ -133,7 +133,7 @@ func (p *Processor) FollowRemove(ctx context.Context, requestingAccount *gtsmode
} }
// make sure the target account actually exists in our db // make sure the target account actually exists in our db
targetAcct, err := p.db.GetAccountByID(ctx, targetAccountID) targetAcct, err := p.state.DB.GetAccountByID(ctx, targetAccountID)
if err != nil { if err != nil {
if err == db.ErrNoEntries { if err == db.ErrNoEntries {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("AccountFollowRemove: account %s not found in the db: %s", targetAccountID, err)) return nil, gtserror.NewErrorNotFound(fmt.Errorf("AccountFollowRemove: account %s not found in the db: %s", targetAccountID, err))
@ -144,12 +144,12 @@ func (p *Processor) FollowRemove(ctx context.Context, requestingAccount *gtsmode
var frChanged bool var frChanged bool
var frURI string var frURI string
fr := &gtsmodel.FollowRequest{} fr := &gtsmodel.FollowRequest{}
if err := p.db.GetWhere(ctx, []db.Where{ if err := p.state.DB.GetWhere(ctx, []db.Where{
{Key: "account_id", Value: requestingAccount.ID}, {Key: "account_id", Value: requestingAccount.ID},
{Key: "target_account_id", Value: targetAccountID}, {Key: "target_account_id", Value: targetAccountID},
}, fr); err == nil { }, fr); err == nil {
frURI = fr.URI frURI = fr.URI
if err := p.db.DeleteByID(ctx, fr.ID, fr); err != nil { if err := p.state.DB.DeleteByID(ctx, fr.ID, fr); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("AccountFollowRemove: error removing follow request from db: %s", err)) return nil, gtserror.NewErrorInternalError(fmt.Errorf("AccountFollowRemove: error removing follow request from db: %s", err))
} }
frChanged = true frChanged = true
@ -159,12 +159,12 @@ func (p *Processor) FollowRemove(ctx context.Context, requestingAccount *gtsmode
var fChanged bool var fChanged bool
var fURI string var fURI string
f := &gtsmodel.Follow{} f := &gtsmodel.Follow{}
if err := p.db.GetWhere(ctx, []db.Where{ if err := p.state.DB.GetWhere(ctx, []db.Where{
{Key: "account_id", Value: requestingAccount.ID}, {Key: "account_id", Value: requestingAccount.ID},
{Key: "target_account_id", Value: targetAccountID}, {Key: "target_account_id", Value: targetAccountID},
}, f); err == nil { }, f); err == nil {
fURI = f.URI fURI = f.URI
if err := p.db.DeleteByID(ctx, f.ID, f); err != nil { if err := p.state.DB.DeleteByID(ctx, f.ID, f); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("AccountFollowRemove: error removing follow from db: %s", err)) return nil, gtserror.NewErrorInternalError(fmt.Errorf("AccountFollowRemove: error removing follow from db: %s", err))
} }
fChanged = true fChanged = true
@ -172,7 +172,7 @@ func (p *Processor) FollowRemove(ctx context.Context, requestingAccount *gtsmode
// follow request status changed so send the UNDO activity to the channel for async processing // follow request status changed so send the UNDO activity to the channel for async processing
if frChanged { if frChanged {
p.clientWorker.Queue(messages.FromClientAPI{ p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{
APObjectType: ap.ActivityFollow, APObjectType: ap.ActivityFollow,
APActivityType: ap.ActivityUndo, APActivityType: ap.ActivityUndo,
GTSModel: &gtsmodel.Follow{ GTSModel: &gtsmodel.Follow{
@ -187,7 +187,7 @@ func (p *Processor) FollowRemove(ctx context.Context, requestingAccount *gtsmode
// follow status changed so send the UNDO activity to the channel for async processing // follow status changed so send the UNDO activity to the channel for async processing
if fChanged { if fChanged {
p.clientWorker.Queue(messages.FromClientAPI{ p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{
APObjectType: ap.ActivityFollow, APObjectType: ap.ActivityFollow,
APActivityType: ap.ActivityUndo, APActivityType: ap.ActivityUndo,
GTSModel: &gtsmodel.Follow{ GTSModel: &gtsmodel.Follow{

View file

@ -33,7 +33,7 @@ import (
// Get processes the given request for account information. // Get processes the given request for account information.
func (p *Processor) Get(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Account, gtserror.WithCode) { func (p *Processor) Get(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Account, gtserror.WithCode) {
targetAccount, err := p.db.GetAccountByID(ctx, targetAccountID) targetAccount, err := p.state.DB.GetAccountByID(ctx, targetAccountID)
if err != nil { if err != nil {
if err == db.ErrNoEntries { if err == db.ErrNoEntries {
return nil, gtserror.NewErrorNotFound(errors.New("account not found")) return nil, gtserror.NewErrorNotFound(errors.New("account not found"))
@ -46,7 +46,7 @@ func (p *Processor) Get(ctx context.Context, requestingAccount *gtsmodel.Account
// GetLocalByUsername processes the given request for account information targeting a local account by username. // GetLocalByUsername processes the given request for account information targeting a local account by username.
func (p *Processor) GetLocalByUsername(ctx context.Context, requestingAccount *gtsmodel.Account, username string) (*apimodel.Account, gtserror.WithCode) { func (p *Processor) GetLocalByUsername(ctx context.Context, requestingAccount *gtsmodel.Account, username string) (*apimodel.Account, gtserror.WithCode) {
targetAccount, err := p.db.GetAccountByUsernameDomain(ctx, username, "") targetAccount, err := p.state.DB.GetAccountByUsernameDomain(ctx, username, "")
if err != nil { if err != nil {
if err == db.ErrNoEntries { if err == db.ErrNoEntries {
return nil, gtserror.NewErrorNotFound(errors.New("account not found")) return nil, gtserror.NewErrorNotFound(errors.New("account not found"))
@ -59,7 +59,7 @@ func (p *Processor) GetLocalByUsername(ctx context.Context, requestingAccount *g
// GetCustomCSSForUsername returns custom css for the given local username. // GetCustomCSSForUsername returns custom css for the given local username.
func (p *Processor) GetCustomCSSForUsername(ctx context.Context, username string) (string, gtserror.WithCode) { func (p *Processor) GetCustomCSSForUsername(ctx context.Context, username string) (string, gtserror.WithCode) {
customCSS, err := p.db.GetAccountCustomCSSByUsername(ctx, username) customCSS, err := p.state.DB.GetAccountCustomCSSByUsername(ctx, username)
if err != nil { if err != nil {
if err == db.ErrNoEntries { if err == db.ErrNoEntries {
return "", gtserror.NewErrorNotFound(errors.New("account not found")) return "", gtserror.NewErrorNotFound(errors.New("account not found"))
@ -74,7 +74,7 @@ func (p *Processor) getFor(ctx context.Context, requestingAccount *gtsmodel.Acco
var blocked bool var blocked bool
var err error var err error
if requestingAccount != nil { if requestingAccount != nil {
blocked, err = p.db.IsBlocked(ctx, requestingAccount.ID, targetAccount.ID, true) blocked, err = p.state.DB.IsBlocked(ctx, requestingAccount.ID, targetAccount.ID, true)
if err != nil { if err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("error checking account block: %s", err)) return nil, gtserror.NewErrorInternalError(fmt.Errorf("error checking account block: %s", err))
} }

View file

@ -31,14 +31,14 @@ import (
// FollowersGet fetches a list of the target account's followers. // FollowersGet fetches a list of the target account's followers.
func (p *Processor) FollowersGet(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) ([]apimodel.Account, gtserror.WithCode) { func (p *Processor) FollowersGet(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) ([]apimodel.Account, gtserror.WithCode) {
if blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, targetAccountID, true); err != nil { if blocked, err := p.state.DB.IsBlocked(ctx, requestingAccount.ID, targetAccountID, true); err != nil {
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
} else if blocked { } else if blocked {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("block exists between accounts")) return nil, gtserror.NewErrorNotFound(fmt.Errorf("block exists between accounts"))
} }
accounts := []apimodel.Account{} accounts := []apimodel.Account{}
follows, err := p.db.GetAccountFollowedBy(ctx, targetAccountID, false) follows, err := p.state.DB.GetAccountFollowedBy(ctx, targetAccountID, false)
if err != nil { if err != nil {
if err == db.ErrNoEntries { if err == db.ErrNoEntries {
return accounts, nil return accounts, nil
@ -47,7 +47,7 @@ func (p *Processor) FollowersGet(ctx context.Context, requestingAccount *gtsmode
} }
for _, f := range follows { for _, f := range follows {
blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, f.AccountID, true) blocked, err := p.state.DB.IsBlocked(ctx, requestingAccount.ID, f.AccountID, true)
if err != nil { if err != nil {
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
} }
@ -56,7 +56,7 @@ func (p *Processor) FollowersGet(ctx context.Context, requestingAccount *gtsmode
} }
if f.Account == nil { if f.Account == nil {
a, err := p.db.GetAccountByID(ctx, f.AccountID) a, err := p.state.DB.GetAccountByID(ctx, f.AccountID)
if err != nil { if err != nil {
if err == db.ErrNoEntries { if err == db.ErrNoEntries {
continue continue
@ -77,14 +77,14 @@ func (p *Processor) FollowersGet(ctx context.Context, requestingAccount *gtsmode
// FollowingGet fetches a list of the accounts that target account is following. // FollowingGet fetches a list of the accounts that target account is following.
func (p *Processor) FollowingGet(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) ([]apimodel.Account, gtserror.WithCode) { func (p *Processor) FollowingGet(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) ([]apimodel.Account, gtserror.WithCode) {
if blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, targetAccountID, true); err != nil { if blocked, err := p.state.DB.IsBlocked(ctx, requestingAccount.ID, targetAccountID, true); err != nil {
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
} else if blocked { } else if blocked {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("block exists between accounts")) return nil, gtserror.NewErrorNotFound(fmt.Errorf("block exists between accounts"))
} }
accounts := []apimodel.Account{} accounts := []apimodel.Account{}
follows, err := p.db.GetAccountFollows(ctx, targetAccountID) follows, err := p.state.DB.GetAccountFollows(ctx, targetAccountID)
if err != nil { if err != nil {
if err == db.ErrNoEntries { if err == db.ErrNoEntries {
return accounts, nil return accounts, nil
@ -93,7 +93,7 @@ func (p *Processor) FollowingGet(ctx context.Context, requestingAccount *gtsmode
} }
for _, f := range follows { for _, f := range follows {
blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, f.AccountID, true) blocked, err := p.state.DB.IsBlocked(ctx, requestingAccount.ID, f.AccountID, true)
if err != nil { if err != nil {
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
} }
@ -102,7 +102,7 @@ func (p *Processor) FollowingGet(ctx context.Context, requestingAccount *gtsmode
} }
if f.TargetAccount == nil { if f.TargetAccount == nil {
a, err := p.db.GetAccountByID(ctx, f.TargetAccountID) a, err := p.state.DB.GetAccountByID(ctx, f.TargetAccountID)
if err != nil { if err != nil {
if err == db.ErrNoEntries { if err == db.ErrNoEntries {
continue continue
@ -127,7 +127,7 @@ func (p *Processor) RelationshipGet(ctx context.Context, requestingAccount *gtsm
return nil, gtserror.NewErrorForbidden(errors.New("not authed")) return nil, gtserror.NewErrorForbidden(errors.New("not authed"))
} }
gtsR, err := p.db.GetRelationship(ctx, requestingAccount.ID, targetAccountID) gtsR, err := p.state.DB.GetRelationship(ctx, requestingAccount.ID, targetAccountID)
if err != nil { if err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("error getting relationship: %s", err)) return nil, gtserror.NewErrorInternalError(fmt.Errorf("error getting relationship: %s", err))
} }

View file

@ -34,7 +34,7 @@ const rssFeedLength = 20
// GetRSSFeedForUsername returns RSS feed for the given local username. // GetRSSFeedForUsername returns RSS feed for the given local username.
func (p *Processor) GetRSSFeedForUsername(ctx context.Context, username string) (func() (string, gtserror.WithCode), time.Time, gtserror.WithCode) { func (p *Processor) GetRSSFeedForUsername(ctx context.Context, username string) (func() (string, gtserror.WithCode), time.Time, gtserror.WithCode) {
account, err := p.db.GetAccountByUsernameDomain(ctx, username, "") account, err := p.state.DB.GetAccountByUsernameDomain(ctx, username, "")
if err != nil { if err != nil {
if err == db.ErrNoEntries { if err == db.ErrNoEntries {
return nil, time.Time{}, gtserror.NewErrorNotFound(errors.New("GetRSSFeedForUsername: account not found")) return nil, time.Time{}, gtserror.NewErrorNotFound(errors.New("GetRSSFeedForUsername: account not found"))
@ -46,13 +46,13 @@ func (p *Processor) GetRSSFeedForUsername(ctx context.Context, username string)
return nil, time.Time{}, gtserror.NewErrorNotFound(errors.New("GetRSSFeedForUsername: account RSS feed not enabled")) return nil, time.Time{}, gtserror.NewErrorNotFound(errors.New("GetRSSFeedForUsername: account RSS feed not enabled"))
} }
lastModified, err := p.db.GetAccountLastPosted(ctx, account.ID, true) lastModified, err := p.state.DB.GetAccountLastPosted(ctx, account.ID, true)
if err != nil { if err != nil {
return nil, time.Time{}, gtserror.NewErrorInternalError(fmt.Errorf("GetRSSFeedForUsername: db error: %s", err)) return nil, time.Time{}, gtserror.NewErrorInternalError(fmt.Errorf("GetRSSFeedForUsername: db error: %s", err))
} }
return func() (string, gtserror.WithCode) { return func() (string, gtserror.WithCode) {
statuses, err := p.db.GetAccountWebStatuses(ctx, account.ID, rssFeedLength, "") statuses, err := p.state.DB.GetAccountWebStatuses(ctx, account.ID, rssFeedLength, "")
if err != nil && err != db.ErrNoEntries { if err != nil && err != db.ErrNoEntries {
return "", gtserror.NewErrorInternalError(fmt.Errorf("GetRSSFeedForUsername: db error: %s", err)) return "", gtserror.NewErrorInternalError(fmt.Errorf("GetRSSFeedForUsername: db error: %s", err))
} }
@ -65,7 +65,7 @@ func (p *Processor) GetRSSFeedForUsername(ctx context.Context, username string)
var image *feeds.Image var image *feeds.Image
if account.AvatarMediaAttachmentID != "" { if account.AvatarMediaAttachmentID != "" {
if account.AvatarMediaAttachment == nil { if account.AvatarMediaAttachment == nil {
avatar, err := p.db.GetAttachmentByID(ctx, account.AvatarMediaAttachmentID) avatar, err := p.state.DB.GetAttachmentByID(ctx, account.AvatarMediaAttachmentID)
if err != nil { if err != nil {
return "", gtserror.NewErrorInternalError(fmt.Errorf("GetRSSFeedForUsername: db error fetching avatar attachment: %s", err)) return "", gtserror.NewErrorInternalError(fmt.Errorf("GetRSSFeedForUsername: db error fetching avatar attachment: %s", err))
} }

View file

@ -33,7 +33,7 @@ import (
// the account given in authed. // the account given in authed.
func (p *Processor) StatusesGet(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string, limit int, excludeReplies bool, excludeReblogs bool, maxID string, minID string, pinned bool, mediaOnly bool, publicOnly bool) (*apimodel.PageableResponse, gtserror.WithCode) { func (p *Processor) StatusesGet(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string, limit int, excludeReplies bool, excludeReblogs bool, maxID string, minID string, pinned bool, mediaOnly bool, publicOnly bool) (*apimodel.PageableResponse, gtserror.WithCode) {
if requestingAccount != nil { if requestingAccount != nil {
if blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, targetAccountID, true); err != nil { if blocked, err := p.state.DB.IsBlocked(ctx, requestingAccount.ID, targetAccountID, true); err != nil {
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
} else if blocked { } else if blocked {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("block exists between accounts")) return nil, gtserror.NewErrorNotFound(fmt.Errorf("block exists between accounts"))
@ -46,10 +46,10 @@ func (p *Processor) StatusesGet(ctx context.Context, requestingAccount *gtsmodel
) )
if pinned { if pinned {
// Get *ONLY* pinned statuses. // Get *ONLY* pinned statuses.
statuses, err = p.db.GetAccountPinnedStatuses(ctx, targetAccountID) statuses, err = p.state.DB.GetAccountPinnedStatuses(ctx, targetAccountID)
} else { } else {
// Get account statuses which *may* include pinned ones. // Get account statuses which *may* include pinned ones.
statuses, err = p.db.GetAccountStatuses(ctx, targetAccountID, limit, excludeReplies, excludeReblogs, maxID, minID, mediaOnly, publicOnly) statuses, err = p.state.DB.GetAccountStatuses(ctx, targetAccountID, limit, excludeReplies, excludeReblogs, maxID, minID, mediaOnly, publicOnly)
} }
if err != nil { if err != nil {
if err == db.ErrNoEntries { if err == db.ErrNoEntries {
@ -120,7 +120,7 @@ func (p *Processor) StatusesGet(ctx context.Context, requestingAccount *gtsmodel
// WebStatusesGet fetches a number of statuses (in descending order) from the given account. It selects only // WebStatusesGet fetches a number of statuses (in descending order) from the given account. It selects only
// statuses which are suitable for showing on the public web profile of an account. // statuses which are suitable for showing on the public web profile of an account.
func (p *Processor) WebStatusesGet(ctx context.Context, targetAccountID string, maxID string) (*apimodel.PageableResponse, gtserror.WithCode) { func (p *Processor) WebStatusesGet(ctx context.Context, targetAccountID string, maxID string) (*apimodel.PageableResponse, gtserror.WithCode) {
acct, err := p.db.GetAccountByID(ctx, targetAccountID) acct, err := p.state.DB.GetAccountByID(ctx, targetAccountID)
if err != nil { if err != nil {
if err == db.ErrNoEntries { if err == db.ErrNoEntries {
err := fmt.Errorf("account %s not found in the db, not getting web statuses for it", targetAccountID) err := fmt.Errorf("account %s not found in the db, not getting web statuses for it", targetAccountID)
@ -134,7 +134,7 @@ func (p *Processor) WebStatusesGet(ctx context.Context, targetAccountID string,
return nil, gtserror.NewErrorNotFound(err) return nil, gtserror.NewErrorNotFound(err)
} }
statuses, err := p.db.GetAccountWebStatuses(ctx, targetAccountID, 10, maxID) statuses, err := p.state.DB.GetAccountWebStatuses(ctx, targetAccountID, 10, maxID)
if err != nil { if err != nil {
if err == db.ErrNoEntries { if err == db.ErrNoEntries {
return util.EmptyPageableResponse(), nil return util.EmptyPageableResponse(), nil

View file

@ -165,12 +165,12 @@ func (p *Processor) Update(ctx context.Context, account *gtsmodel.Account, form
account.EnableRSS = form.EnableRSS account.EnableRSS = form.EnableRSS
} }
err := p.db.UpdateAccount(ctx, account) err := p.state.DB.UpdateAccount(ctx, account)
if err != nil { if err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("could not update account %s: %s", account.ID, err)) return nil, gtserror.NewErrorInternalError(fmt.Errorf("could not update account %s: %s", account.ID, err))
} }
p.clientWorker.Queue(messages.FromClientAPI{ p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{
APObjectType: ap.ObjectProfile, APObjectType: ap.ObjectProfile,
APActivityType: ap.ActivityUpdate, APActivityType: ap.ActivityUpdate,
GTSModel: account, GTSModel: account,

View file

@ -31,7 +31,7 @@ import (
) )
func (p *Processor) AccountAction(ctx context.Context, account *gtsmodel.Account, form *apimodel.AdminAccountActionRequest) gtserror.WithCode { func (p *Processor) AccountAction(ctx context.Context, account *gtsmodel.Account, form *apimodel.AdminAccountActionRequest) gtserror.WithCode {
targetAccount, err := p.db.GetAccountByID(ctx, form.TargetAccountID) targetAccount, err := p.state.DB.GetAccountByID(ctx, form.TargetAccountID)
if err != nil { if err != nil {
return gtserror.NewErrorInternalError(err) return gtserror.NewErrorInternalError(err)
} }
@ -47,7 +47,7 @@ func (p *Processor) AccountAction(ctx context.Context, account *gtsmodel.Account
case string(gtsmodel.AdminActionSuspend): case string(gtsmodel.AdminActionSuspend):
adminAction.Type = gtsmodel.AdminActionSuspend adminAction.Type = gtsmodel.AdminActionSuspend
// pass the account delete through the client api channel for processing // pass the account delete through the client api channel for processing
p.clientWorker.Queue(messages.FromClientAPI{ p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{
APObjectType: ap.ActorPerson, APObjectType: ap.ActorPerson,
APActivityType: ap.ActivityDelete, APActivityType: ap.ActivityDelete,
OriginAccount: account, OriginAccount: account,
@ -57,7 +57,7 @@ func (p *Processor) AccountAction(ctx context.Context, account *gtsmodel.Account
return gtserror.NewErrorBadRequest(fmt.Errorf("admin action type %s is not supported for this endpoint", form.Type)) return gtserror.NewErrorBadRequest(fmt.Errorf("admin action type %s is not supported for this endpoint", form.Type))
} }
if err := p.db.Put(ctx, adminAction); err != nil { if err := p.state.DB.Put(ctx, adminAction); err != nil {
return gtserror.NewErrorInternalError(err) return gtserror.NewErrorInternalError(err)
} }

View file

@ -19,32 +19,25 @@
package admin package admin
import ( import (
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/media" "github.com/superseriousbusiness/gotosocial/internal/media"
"github.com/superseriousbusiness/gotosocial/internal/messages" "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/transport" "github.com/superseriousbusiness/gotosocial/internal/transport"
"github.com/superseriousbusiness/gotosocial/internal/typeutils" "github.com/superseriousbusiness/gotosocial/internal/typeutils"
) )
type Processor struct { type Processor struct {
state *state.State
tc typeutils.TypeConverter tc typeutils.TypeConverter
mediaManager media.Manager mediaManager media.Manager
transportController transport.Controller transportController transport.Controller
storage *storage.Driver
clientWorker *concurrency.WorkerPool[messages.FromClientAPI]
db db.DB
} }
// New returns a new admin processor. // New returns a new admin processor.
func New(db db.DB, tc typeutils.TypeConverter, mediaManager media.Manager, transportController transport.Controller, storage *storage.Driver, clientWorker *concurrency.WorkerPool[messages.FromClientAPI]) Processor { func New(state *state.State, tc typeutils.TypeConverter, mediaManager media.Manager, transportController transport.Controller) Processor {
return Processor{ return Processor{
state: state,
tc: tc, tc: tc,
mediaManager: mediaManager, mediaManager: mediaManager,
transportController: transportController, transportController: transportController,
storage: storage,
clientWorker: clientWorker,
db: db,
} }
} }

View file

@ -28,7 +28,7 @@ func (p *Processor) DomainBlockCreate(ctx context.Context, account *gtsmodel.Acc
domain = strings.ToLower(domain) domain = strings.ToLower(domain)
// first check if we already have a block -- if err == nil we already had a block so we can skip a whole lot of work // first check if we already have a block -- if err == nil we already had a block so we can skip a whole lot of work
block, err := p.db.GetDomainBlock(ctx, domain) block, err := p.state.DB.GetDomainBlock(ctx, domain)
if err != nil { if err != nil {
if !errors.Is(err, db.ErrNoEntries) { if !errors.Is(err, db.ErrNoEntries) {
// something went wrong in the DB // something went wrong in the DB
@ -47,7 +47,7 @@ func (p *Processor) DomainBlockCreate(ctx context.Context, account *gtsmodel.Acc
} }
// Insert the new block into the database // Insert the new block into the database
if err := p.db.CreateDomainBlock(ctx, newBlock); err != nil { if err := p.state.DB.CreateDomainBlock(ctx, newBlock); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error putting new domain block %s: %s", domain, err)) return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error putting new domain block %s: %s", domain, err))
} }
@ -80,7 +80,7 @@ func (p *Processor) initiateDomainBlockSideEffects(ctx context.Context, account
// if we have an instance entry for this domain, update it with the new block ID and clear all fields // if we have an instance entry for this domain, update it with the new block ID and clear all fields
instance := &gtsmodel.Instance{} instance := &gtsmodel.Instance{}
if err := p.db.GetWhere(ctx, []db.Where{{Key: "domain", Value: block.Domain}}, instance); err == nil { if err := p.state.DB.GetWhere(ctx, []db.Where{{Key: "domain", Value: block.Domain}}, instance); err == nil {
updatingColumns := []string{ updatingColumns := []string{
"title", "title",
"updated_at", "updated_at",
@ -105,15 +105,15 @@ func (p *Processor) initiateDomainBlockSideEffects(ctx context.Context, account
instance.ContactAccountUsername = "" instance.ContactAccountUsername = ""
instance.ContactAccountID = "" instance.ContactAccountID = ""
instance.Version = "" instance.Version = ""
if err := p.db.UpdateByID(ctx, instance, instance.ID, updatingColumns...); err != nil { if err := p.state.DB.UpdateByID(ctx, instance, instance.ID, updatingColumns...); err != nil {
l.Errorf("domainBlockProcessSideEffects: db error updating instance: %s", err) l.Errorf("domainBlockProcessSideEffects: db error updating instance: %s", err)
} }
l.Debug("domainBlockProcessSideEffects: instance entry updated") l.Debug("domainBlockProcessSideEffects: instance entry updated")
} }
// if we have an instance account for this instance, delete it // if we have an instance account for this instance, delete it
if instanceAccount, err := p.db.GetAccountByUsernameDomain(ctx, block.Domain, block.Domain); err == nil { if instanceAccount, err := p.state.DB.GetAccountByUsernameDomain(ctx, block.Domain, block.Domain); err == nil {
if err := p.db.DeleteAccount(ctx, instanceAccount.ID); err != nil { if err := p.state.DB.DeleteAccount(ctx, instanceAccount.ID); err != nil {
l.Errorf("domainBlockProcessSideEffects: db error deleting instance account: %s", err) l.Errorf("domainBlockProcessSideEffects: db error deleting instance account: %s", err)
} }
} }
@ -125,7 +125,7 @@ func (p *Processor) initiateDomainBlockSideEffects(ctx context.Context, account
selectAccountsLoop: selectAccountsLoop:
for { for {
accounts, err := p.db.GetInstanceAccounts(ctx, block.Domain, maxID, limit) accounts, err := p.state.DB.GetInstanceAccounts(ctx, block.Domain, maxID, limit)
if err != nil { if err != nil {
if err == db.ErrNoEntries { if err == db.ErrNoEntries {
// no accounts left for this instance so we're done // no accounts left for this instance so we're done
@ -141,7 +141,7 @@ selectAccountsLoop:
l.Debugf("putting delete for account %s in the clientAPI channel", a.Username) l.Debugf("putting delete for account %s in the clientAPI channel", a.Username)
// pass the account delete through the client api channel for processing // pass the account delete through the client api channel for processing
p.clientWorker.Queue(messages.FromClientAPI{ p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{
APObjectType: ap.ActorPerson, APObjectType: ap.ActorPerson,
APActivityType: ap.ActivityDelete, APActivityType: ap.ActivityDelete,
GTSModel: block, GTSModel: block,
@ -195,7 +195,7 @@ func (p *Processor) DomainBlocksImport(ctx context.Context, account *gtsmodel.Ac
func (p *Processor) DomainBlocksGet(ctx context.Context, account *gtsmodel.Account, export bool) ([]*apimodel.DomainBlock, gtserror.WithCode) { func (p *Processor) DomainBlocksGet(ctx context.Context, account *gtsmodel.Account, export bool) ([]*apimodel.DomainBlock, gtserror.WithCode) {
domainBlocks := []*gtsmodel.DomainBlock{} domainBlocks := []*gtsmodel.DomainBlock{}
if err := p.db.GetAll(ctx, &domainBlocks); err != nil { if err := p.state.DB.GetAll(ctx, &domainBlocks); err != nil {
if !errors.Is(err, db.ErrNoEntries) { if !errors.Is(err, db.ErrNoEntries) {
// something has gone really wrong // something has gone really wrong
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
@ -219,7 +219,7 @@ func (p *Processor) DomainBlocksGet(ctx context.Context, account *gtsmodel.Accou
func (p *Processor) DomainBlockGet(ctx context.Context, account *gtsmodel.Account, id string, export bool) (*apimodel.DomainBlock, gtserror.WithCode) { func (p *Processor) DomainBlockGet(ctx context.Context, account *gtsmodel.Account, id string, export bool) (*apimodel.DomainBlock, gtserror.WithCode) {
domainBlock := &gtsmodel.DomainBlock{} domainBlock := &gtsmodel.DomainBlock{}
if err := p.db.GetByID(ctx, id, domainBlock); err != nil { if err := p.state.DB.GetByID(ctx, id, domainBlock); err != nil {
if !errors.Is(err, db.ErrNoEntries) { if !errors.Is(err, db.ErrNoEntries) {
// something has gone really wrong // something has gone really wrong
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
@ -240,7 +240,7 @@ func (p *Processor) DomainBlockGet(ctx context.Context, account *gtsmodel.Accoun
func (p *Processor) DomainBlockDelete(ctx context.Context, account *gtsmodel.Account, id string) (*apimodel.DomainBlock, gtserror.WithCode) { func (p *Processor) DomainBlockDelete(ctx context.Context, account *gtsmodel.Account, id string) (*apimodel.DomainBlock, gtserror.WithCode) {
domainBlock := &gtsmodel.DomainBlock{} domainBlock := &gtsmodel.DomainBlock{}
if err := p.db.GetByID(ctx, id, domainBlock); err != nil { if err := p.state.DB.GetByID(ctx, id, domainBlock); err != nil {
if !errors.Is(err, db.ErrNoEntries) { if !errors.Is(err, db.ErrNoEntries) {
// something has gone really wrong // something has gone really wrong
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
@ -256,13 +256,13 @@ func (p *Processor) DomainBlockDelete(ctx context.Context, account *gtsmodel.Acc
} }
// Delete the domain block // Delete the domain block
if err := p.db.DeleteDomainBlock(ctx, domainBlock.Domain); err != nil { if err := p.state.DB.DeleteDomainBlock(ctx, domainBlock.Domain); err != nil {
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
} }
// remove the domain block reference from the instance, if we have an entry for it // remove the domain block reference from the instance, if we have an entry for it
i := &gtsmodel.Instance{} i := &gtsmodel.Instance{}
if err := p.db.GetWhere(ctx, []db.Where{ if err := p.state.DB.GetWhere(ctx, []db.Where{
{Key: "domain", Value: domainBlock.Domain}, {Key: "domain", Value: domainBlock.Domain},
{Key: "domain_block_id", Value: id}, {Key: "domain_block_id", Value: id},
}, i); err == nil { }, i); err == nil {
@ -270,21 +270,21 @@ func (p *Processor) DomainBlockDelete(ctx context.Context, account *gtsmodel.Acc
i.SuspendedAt = time.Time{} i.SuspendedAt = time.Time{}
i.DomainBlockID = "" i.DomainBlockID = ""
i.UpdatedAt = time.Now() i.UpdatedAt = time.Now()
if err := p.db.UpdateByID(ctx, i, i.ID, updatingColumns...); err != nil { if err := p.state.DB.UpdateByID(ctx, i, i.ID, updatingColumns...); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("couldn't update database entry for instance %s: %s", domainBlock.Domain, err)) return nil, gtserror.NewErrorInternalError(fmt.Errorf("couldn't update database entry for instance %s: %s", domainBlock.Domain, err))
} }
} }
// unsuspend all accounts whose suspension origin was this domain block // unsuspend all accounts whose suspension origin was this domain block
// 1. remove the 'suspended_at' entry from their accounts // 1. remove the 'suspended_at' entry from their accounts
if err := p.db.UpdateWhere(ctx, []db.Where{ if err := p.state.DB.UpdateWhere(ctx, []db.Where{
{Key: "suspension_origin", Value: domainBlock.ID}, {Key: "suspension_origin", Value: domainBlock.ID},
}, "suspended_at", nil, &[]*gtsmodel.Account{}); err != nil { }, "suspended_at", nil, &[]*gtsmodel.Account{}); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("database error removing suspended_at from accounts: %s", err)) return nil, gtserror.NewErrorInternalError(fmt.Errorf("database error removing suspended_at from accounts: %s", err))
} }
// 2. remove the 'suspension_origin' entry from their accounts // 2. remove the 'suspension_origin' entry from their accounts
if err := p.db.UpdateWhere(ctx, []db.Where{ if err := p.state.DB.UpdateWhere(ctx, []db.Where{
{Key: "suspension_origin", Value: domainBlock.ID}, {Key: "suspension_origin", Value: domainBlock.ID},
}, "suspension_origin", nil, &[]*gtsmodel.Account{}); err != nil { }, "suspension_origin", nil, &[]*gtsmodel.Account{}); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("database error removing suspension_origin from accounts: %s", err)) return nil, gtserror.NewErrorInternalError(fmt.Errorf("database error removing suspension_origin from accounts: %s", err))

View file

@ -42,7 +42,7 @@ func (p *Processor) EmojiCreate(ctx context.Context, account *gtsmodel.Account,
return nil, gtserror.NewErrorUnauthorized(fmt.Errorf("user %s not an admin", user.ID), "user is not an admin") return nil, gtserror.NewErrorUnauthorized(fmt.Errorf("user %s not an admin", user.ID), "user is not an admin")
} }
maybeExisting, err := p.db.GetEmojiByShortcodeDomain(ctx, form.Shortcode, "") maybeExisting, err := p.state.DB.GetEmojiByShortcodeDomain(ctx, form.Shortcode, "")
if maybeExisting != nil { if maybeExisting != nil {
return nil, gtserror.NewErrorConflict(fmt.Errorf("emoji with shortcode %s already exists", form.Shortcode), fmt.Sprintf("emoji with shortcode %s already exists", form.Shortcode)) return nil, gtserror.NewErrorConflict(fmt.Errorf("emoji with shortcode %s already exists", form.Shortcode), fmt.Sprintf("emoji with shortcode %s already exists", form.Shortcode))
} }
@ -110,7 +110,7 @@ func (p *Processor) EmojisGet(
return nil, gtserror.NewErrorUnauthorized(fmt.Errorf("user %s not an admin", user.ID), "user is not an admin") return nil, gtserror.NewErrorUnauthorized(fmt.Errorf("user %s not an admin", user.ID), "user is not an admin")
} }
emojis, err := p.db.GetEmojis(ctx, domain, includeDisabled, includeEnabled, shortcode, maxShortcodeDomain, minShortcodeDomain, limit) emojis, err := p.state.DB.GetEmojis(ctx, domain, includeDisabled, includeEnabled, shortcode, maxShortcodeDomain, minShortcodeDomain, limit)
if err != nil && !errors.Is(err, db.ErrNoEntries) { if err != nil && !errors.Is(err, db.ErrNoEntries) {
err := fmt.Errorf("EmojisGet: db error: %s", err) err := fmt.Errorf("EmojisGet: db error: %s", err)
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
@ -176,7 +176,7 @@ func (p *Processor) EmojiGet(ctx context.Context, account *gtsmodel.Account, use
return nil, gtserror.NewErrorUnauthorized(fmt.Errorf("user %s not an admin", user.ID), "user is not an admin") return nil, gtserror.NewErrorUnauthorized(fmt.Errorf("user %s not an admin", user.ID), "user is not an admin")
} }
emoji, err := p.db.GetEmojiByID(ctx, id) emoji, err := p.state.DB.GetEmojiByID(ctx, id)
if err != nil { if err != nil {
if errors.Is(err, db.ErrNoEntries) { if errors.Is(err, db.ErrNoEntries) {
err = fmt.Errorf("EmojiGet: no emoji with id %s found in the db", id) err = fmt.Errorf("EmojiGet: no emoji with id %s found in the db", id)
@ -197,7 +197,7 @@ func (p *Processor) EmojiGet(ctx context.Context, account *gtsmodel.Account, use
// EmojiDelete deletes one emoji from the database, with the given id. // EmojiDelete deletes one emoji from the database, with the given id.
func (p *Processor) EmojiDelete(ctx context.Context, id string) (*apimodel.AdminEmoji, gtserror.WithCode) { func (p *Processor) EmojiDelete(ctx context.Context, id string) (*apimodel.AdminEmoji, gtserror.WithCode) {
emoji, err := p.db.GetEmojiByID(ctx, id) emoji, err := p.state.DB.GetEmojiByID(ctx, id)
if err != nil { if err != nil {
if errors.Is(err, db.ErrNoEntries) { if errors.Is(err, db.ErrNoEntries) {
err = fmt.Errorf("EmojiDelete: no emoji with id %s found in the db", id) err = fmt.Errorf("EmojiDelete: no emoji with id %s found in the db", id)
@ -218,7 +218,7 @@ func (p *Processor) EmojiDelete(ctx context.Context, id string) (*apimodel.Admin
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
} }
if err := p.db.DeleteEmojiByID(ctx, id); err != nil { if err := p.state.DB.DeleteEmojiByID(ctx, id); err != nil {
err := fmt.Errorf("EmojiDelete: db error: %s", err) err := fmt.Errorf("EmojiDelete: db error: %s", err)
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
} }
@ -228,7 +228,7 @@ func (p *Processor) EmojiDelete(ctx context.Context, id string) (*apimodel.Admin
// EmojiUpdate updates one emoji with the given id, using the provided form parameters. // EmojiUpdate updates one emoji with the given id, using the provided form parameters.
func (p *Processor) EmojiUpdate(ctx context.Context, id string, form *apimodel.EmojiUpdateRequest) (*apimodel.AdminEmoji, gtserror.WithCode) { func (p *Processor) EmojiUpdate(ctx context.Context, id string, form *apimodel.EmojiUpdateRequest) (*apimodel.AdminEmoji, gtserror.WithCode) {
emoji, err := p.db.GetEmojiByID(ctx, id) emoji, err := p.state.DB.GetEmojiByID(ctx, id)
if err != nil { if err != nil {
if errors.Is(err, db.ErrNoEntries) { if errors.Is(err, db.ErrNoEntries) {
err = fmt.Errorf("EmojiUpdate: no emoji with id %s found in the db", id) err = fmt.Errorf("EmojiUpdate: no emoji with id %s found in the db", id)
@ -253,7 +253,7 @@ func (p *Processor) EmojiUpdate(ctx context.Context, id string, form *apimodel.E
// EmojiCategoriesGet returns all custom emoji categories that exist on this instance. // EmojiCategoriesGet returns all custom emoji categories that exist on this instance.
func (p *Processor) EmojiCategoriesGet(ctx context.Context) ([]*apimodel.EmojiCategory, gtserror.WithCode) { func (p *Processor) EmojiCategoriesGet(ctx context.Context) ([]*apimodel.EmojiCategory, gtserror.WithCode) {
categories, err := p.db.GetEmojiCategories(ctx) categories, err := p.state.DB.GetEmojiCategories(ctx)
if err != nil { if err != nil {
err := fmt.Errorf("EmojiCategoriesGet: db error: %s", err) err := fmt.Errorf("EmojiCategoriesGet: db error: %s", err)
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
@ -277,7 +277,7 @@ func (p *Processor) EmojiCategoriesGet(ctx context.Context) ([]*apimodel.EmojiCa
*/ */
func (p *Processor) getOrCreateEmojiCategory(ctx context.Context, name string) (*gtsmodel.EmojiCategory, error) { func (p *Processor) getOrCreateEmojiCategory(ctx context.Context, name string) (*gtsmodel.EmojiCategory, error) {
category, err := p.db.GetEmojiCategoryByName(ctx, name) category, err := p.state.DB.GetEmojiCategoryByName(ctx, name)
if err == nil { if err == nil {
return category, nil return category, nil
} }
@ -299,7 +299,7 @@ func (p *Processor) getOrCreateEmojiCategory(ctx context.Context, name string) (
Name: name, Name: name,
} }
if err := p.db.PutEmojiCategory(ctx, category); err != nil { if err := p.state.DB.PutEmojiCategory(ctx, category); err != nil {
err = fmt.Errorf("GetOrCreateEmojiCategory: error putting new emoji category in the database: %s", err) err = fmt.Errorf("GetOrCreateEmojiCategory: error putting new emoji category in the database: %s", err)
return nil, err return nil, err
} }
@ -319,7 +319,7 @@ func (p *Processor) emojiUpdateCopy(ctx context.Context, emoji *gtsmodel.Emoji,
return nil, gtserror.NewErrorBadRequest(err, err.Error()) return nil, gtserror.NewErrorBadRequest(err, err.Error())
} }
maybeExisting, err := p.db.GetEmojiByShortcodeDomain(ctx, *shortcode, "") maybeExisting, err := p.state.DB.GetEmojiByShortcodeDomain(ctx, *shortcode, "")
if maybeExisting != nil { if maybeExisting != nil {
err := fmt.Errorf("emojiUpdateCopy: emoji %s could not be copied, emoji with shortcode %s already exists on this instance", emoji.ID, *shortcode) err := fmt.Errorf("emojiUpdateCopy: emoji %s could not be copied, emoji with shortcode %s already exists on this instance", emoji.ID, *shortcode)
return nil, gtserror.NewErrorConflict(err, err.Error()) return nil, gtserror.NewErrorConflict(err, err.Error())
@ -339,7 +339,7 @@ func (p *Processor) emojiUpdateCopy(ctx context.Context, emoji *gtsmodel.Emoji,
newEmojiURI := uris.GenerateURIForEmoji(newEmojiID) newEmojiURI := uris.GenerateURIForEmoji(newEmojiID)
data := func(ctx context.Context) (reader io.ReadCloser, fileSize int64, err error) { data := func(ctx context.Context) (reader io.ReadCloser, fileSize int64, err error) {
rc, err := p.storage.GetStream(ctx, emoji.ImagePath) rc, err := p.state.Storage.GetStream(ctx, emoji.ImagePath)
return rc, int64(emoji.ImageFileSize), err return rc, int64(emoji.ImageFileSize), err
} }
@ -386,7 +386,7 @@ func (p *Processor) emojiUpdateDisable(ctx context.Context, emoji *gtsmodel.Emoj
emojiDisabled := true emojiDisabled := true
emoji.Disabled = &emojiDisabled emoji.Disabled = &emojiDisabled
updatedEmoji, err := p.db.UpdateEmoji(ctx, emoji, "updated_at", "disabled") updatedEmoji, err := p.state.DB.UpdateEmoji(ctx, emoji, "updated_at", "disabled")
if err != nil { if err != nil {
err = fmt.Errorf("emojiUpdateDisable: error updating emoji %s: %s", emoji.ID, err) err = fmt.Errorf("emojiUpdateDisable: error updating emoji %s: %s", emoji.ID, err)
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
@ -443,7 +443,7 @@ func (p *Processor) emojiUpdateModify(ctx context.Context, emoji *gtsmodel.Emoji
} }
var err error var err error
updatedEmoji, err = p.db.UpdateEmoji(ctx, emoji, columns...) updatedEmoji, err = p.state.DB.UpdateEmoji(ctx, emoji, columns...)
if err != nil { if err != nil {
err = fmt.Errorf("emojiUpdateModify: error updating emoji %s: %s", emoji.ID, err) err = fmt.Errorf("emojiUpdateModify: error updating emoji %s: %s", emoji.ID, err)
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)

View file

@ -43,7 +43,7 @@ func (p *Processor) ReportsGet(
minID string, minID string,
limit int, limit int,
) (*apimodel.PageableResponse, gtserror.WithCode) { ) (*apimodel.PageableResponse, gtserror.WithCode) {
reports, err := p.db.GetReports(ctx, resolved, accountID, targetAccountID, maxID, sinceID, minID, limit) reports, err := p.state.DB.GetReports(ctx, resolved, accountID, targetAccountID, maxID, sinceID, minID, limit)
if err != nil { if err != nil {
if err == db.ErrNoEntries { if err == db.ErrNoEntries {
return util.EmptyPageableResponse(), nil return util.EmptyPageableResponse(), nil
@ -95,7 +95,7 @@ func (p *Processor) ReportsGet(
// ReportGet returns one report, with the given ID. // ReportGet returns one report, with the given ID.
func (p *Processor) ReportGet(ctx context.Context, account *gtsmodel.Account, id string) (*apimodel.AdminReport, gtserror.WithCode) { func (p *Processor) ReportGet(ctx context.Context, account *gtsmodel.Account, id string) (*apimodel.AdminReport, gtserror.WithCode) {
report, err := p.db.GetReportByID(ctx, id) report, err := p.state.DB.GetReportByID(ctx, id)
if err != nil { if err != nil {
if err == db.ErrNoEntries { if err == db.ErrNoEntries {
return nil, gtserror.NewErrorNotFound(err) return nil, gtserror.NewErrorNotFound(err)
@ -113,7 +113,7 @@ func (p *Processor) ReportGet(ctx context.Context, account *gtsmodel.Account, id
// ReportResolve marks a report with the given id as resolved, and stores the provided actionTakenComment (if not null). // ReportResolve marks a report with the given id as resolved, and stores the provided actionTakenComment (if not null).
func (p *Processor) ReportResolve(ctx context.Context, account *gtsmodel.Account, id string, actionTakenComment *string) (*apimodel.AdminReport, gtserror.WithCode) { func (p *Processor) ReportResolve(ctx context.Context, account *gtsmodel.Account, id string, actionTakenComment *string) (*apimodel.AdminReport, gtserror.WithCode) {
report, err := p.db.GetReportByID(ctx, id) report, err := p.state.DB.GetReportByID(ctx, id)
if err != nil { if err != nil {
if err == db.ErrNoEntries { if err == db.ErrNoEntries {
return nil, gtserror.NewErrorNotFound(err) return nil, gtserror.NewErrorNotFound(err)
@ -134,7 +134,7 @@ func (p *Processor) ReportResolve(ctx context.Context, account *gtsmodel.Account
columns = append(columns, "action_taken") columns = append(columns, "action_taken")
} }
updatedReport, err := p.db.UpdateReport(ctx, report, columns...) updatedReport, err := p.state.DB.UpdateReport(ctx, report, columns...)
if err != nil { if err != nil {
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
} }

View file

@ -62,7 +62,7 @@ func (p *Processor) AppCreate(ctx context.Context, authed *oauth.Auth, form *api
} }
// chuck it in the db // chuck it in the db
if err := p.db.Put(ctx, app); err != nil { if err := p.state.DB.Put(ctx, app); err != nil {
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
} }
@ -76,7 +76,7 @@ func (p *Processor) AppCreate(ctx context.Context, authed *oauth.Auth, form *api
} }
// chuck it in the db // chuck it in the db
if err := p.db.Put(ctx, oc); err != nil { if err := p.state.DB.Put(ctx, oc); err != nil {
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
} }

View file

@ -31,7 +31,7 @@ import (
) )
func (p *Processor) BlocksGet(ctx context.Context, authed *oauth.Auth, maxID string, sinceID string, limit int) (*apimodel.BlocksResponse, gtserror.WithCode) { func (p *Processor) BlocksGet(ctx context.Context, authed *oauth.Auth, maxID string, sinceID string, limit int) (*apimodel.BlocksResponse, gtserror.WithCode) {
accounts, nextMaxID, prevMinID, err := p.db.GetAccountBlocks(ctx, authed.Account.ID, maxID, sinceID, limit) accounts, nextMaxID, prevMinID, err := p.state.DB.GetAccountBlocks(ctx, authed.Account.ID, maxID, sinceID, limit)
if err != nil { if err != nil {
if err == db.ErrNoEntries { if err == db.ErrNoEntries {
// there are just no entries // there are just no entries

View file

@ -84,8 +84,8 @@ func (p *Processor) OutboxGet(ctx context.Context, requestedUsername string, pag
// scenario 2 -- get the requested page // scenario 2 -- get the requested page
// limit pages to 30 entries per page // limit pages to 30 entries per page
publicStatuses, err := p.db.GetAccountStatuses(ctx, requestedAccount.ID, 30, true, true, maxID, minID, false, true) publicStatuses, err := p.state.DB.GetAccountStatuses(ctx, requestedAccount.ID, 30, true, true, maxID, minID, false, true)
if err != nil && err != db.ErrNoEntries { if err != nil && !errors.Is(err, db.ErrNoEntries) {
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
} }
@ -161,7 +161,7 @@ func (p *Processor) FeaturedCollectionGet(ctx context.Context, requestedUsername
return nil, errWithCode return nil, errWithCode
} }
statuses, err := p.db.GetAccountPinnedStatuses(ctx, requestedAccount.ID) statuses, err := p.state.DB.GetAccountPinnedStatuses(ctx, requestedAccount.ID)
if err != nil { if err != nil {
if !errors.Is(err, db.ErrNoEntries) { if !errors.Is(err, db.ErrNoEntries) {
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)

View file

@ -29,7 +29,7 @@ import (
) )
func (p *Processor) authenticate(ctx context.Context, requestedUsername string) (requestedAccount, requestingAccount *gtsmodel.Account, errWithCode gtserror.WithCode) { func (p *Processor) authenticate(ctx context.Context, requestedUsername string) (requestedAccount, requestingAccount *gtsmodel.Account, errWithCode gtserror.WithCode) {
requestedAccount, err := p.db.GetAccountByUsernameDomain(ctx, requestedUsername, "") requestedAccount, err := p.state.DB.GetAccountByUsernameDomain(ctx, requestedUsername, "")
if err != nil { if err != nil {
errWithCode = gtserror.NewErrorNotFound(fmt.Errorf("database error getting account with username %s: %s", requestedUsername, err)) errWithCode = gtserror.NewErrorNotFound(fmt.Errorf("database error getting account with username %s: %s", requestedUsername, err))
return return
@ -46,7 +46,7 @@ func (p *Processor) authenticate(ctx context.Context, requestedUsername string)
return return
} }
blocked, err := p.db.IsBlocked(ctx, requestedAccount.ID, requestingAccount.ID, true) blocked, err := p.state.DB.IsBlocked(ctx, requestedAccount.ID, requestingAccount.ID, true)
if err != nil { if err != nil {
errWithCode = gtserror.NewErrorInternalError(err) errWithCode = gtserror.NewErrorInternalError(err)
return return

View file

@ -32,7 +32,7 @@ func (p *Processor) EmojiGet(ctx context.Context, requestedEmojiID string) (inte
return nil, errWithCode return nil, errWithCode
} }
requestedEmoji, err := p.db.GetEmojiByID(ctx, requestedEmojiID) requestedEmoji, err := p.state.DB.GetEmojiByID(ctx, requestedEmojiID)
if err != nil { if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("database error getting emoji with id %s: %s", requestedEmojiID, err)) return nil, gtserror.NewErrorNotFound(fmt.Errorf("database error getting emoji with id %s: %s", requestedEmojiID, err))
} }

View file

@ -19,25 +19,25 @@
package fedi package fedi
import ( import (
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/federation" "github.com/superseriousbusiness/gotosocial/internal/federation"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/typeutils" "github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/visibility" "github.com/superseriousbusiness/gotosocial/internal/visibility"
) )
type Processor struct { type Processor struct {
db db.DB state *state.State
federator federation.Federator federator federation.Federator
tc typeutils.TypeConverter tc typeutils.TypeConverter
filter visibility.Filter filter visibility.Filter
} }
// New returns a new fedi processor. // New returns a new fedi processor.
func New(db db.DB, tc typeutils.TypeConverter, federator federation.Federator) Processor { func New(state *state.State, tc typeutils.TypeConverter, federator federation.Federator) Processor {
return Processor{ return Processor{
db: db, state: state,
federator: federator, federator: federator,
tc: tc, tc: tc,
filter: visibility.NewFilter(db), filter: visibility.NewFilter(state.DB),
} }
} }

View file

@ -36,7 +36,7 @@ func (p *Processor) StatusGet(ctx context.Context, requestedUsername string, req
return nil, errWithCode return nil, errWithCode
} }
status, err := p.db.GetStatusByID(ctx, requestedStatusID) status, err := p.state.DB.GetStatusByID(ctx, requestedStatusID)
if err != nil { if err != nil {
return nil, gtserror.NewErrorNotFound(err) return nil, gtserror.NewErrorNotFound(err)
} }
@ -74,7 +74,7 @@ func (p *Processor) StatusRepliesGet(ctx context.Context, requestedUsername stri
return nil, errWithCode return nil, errWithCode
} }
status, err := p.db.GetStatusByID(ctx, requestedStatusID) status, err := p.state.DB.GetStatusByID(ctx, requestedStatusID)
if err != nil { if err != nil {
return nil, gtserror.NewErrorNotFound(err) return nil, gtserror.NewErrorNotFound(err)
} }
@ -125,7 +125,7 @@ func (p *Processor) StatusRepliesGet(ctx context.Context, requestedUsername stri
default: default:
// scenario 3 // scenario 3
// get immediate children // get immediate children
replies, err := p.db.GetStatusChildren(ctx, status, true, minID) replies, err := p.state.DB.GetStatusChildren(ctx, status, true, minID)
if err != nil { if err != nil {
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
} }

View file

@ -34,7 +34,7 @@ import (
// before returning a JSON serializable interface to the caller. // before returning a JSON serializable interface to the caller.
func (p *Processor) UserGet(ctx context.Context, requestedUsername string, requestURL *url.URL) (interface{}, gtserror.WithCode) { func (p *Processor) UserGet(ctx context.Context, requestedUsername string, requestURL *url.URL) (interface{}, gtserror.WithCode) {
// Get the instance-local account the request is referring to. // Get the instance-local account the request is referring to.
requestedAccount, err := p.db.GetAccountByUsernameDomain(ctx, requestedUsername, "") requestedAccount, err := p.state.DB.GetAccountByUsernameDomain(ctx, requestedUsername, "")
if err != nil { if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("database error getting account with username %s: %s", requestedUsername, err)) return nil, gtserror.NewErrorNotFound(fmt.Errorf("database error getting account with username %s: %s", requestedUsername, err))
} }
@ -63,7 +63,7 @@ func (p *Processor) UserGet(ctx context.Context, requestedUsername string, reque
return nil, gtserror.NewErrorUnauthorized(err) return nil, gtserror.NewErrorUnauthorized(err)
} }
blocked, err := p.db.IsBlocked(ctx, requestedAccount.ID, requestingAccount.ID, true) blocked, err := p.state.DB.IsBlocked(ctx, requestedAccount.ID, requestingAccount.ID, true)
if err != nil { if err != nil {
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
} }

View file

@ -64,12 +64,12 @@ func (p *Processor) NodeInfoRelGet(ctx context.Context) (*apimodel.WellKnownResp
func (p *Processor) NodeInfoGet(ctx context.Context) (*apimodel.Nodeinfo, gtserror.WithCode) { func (p *Processor) NodeInfoGet(ctx context.Context) (*apimodel.Nodeinfo, gtserror.WithCode) {
host := config.GetHost() host := config.GetHost()
userCount, err := p.db.CountInstanceUsers(ctx, host) userCount, err := p.state.DB.CountInstanceUsers(ctx, host)
if err != nil { if err != nil {
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
} }
postCount, err := p.db.CountInstanceStatuses(ctx, host) postCount, err := p.state.DB.CountInstanceStatuses(ctx, host)
if err != nil { if err != nil {
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
} }
@ -99,7 +99,7 @@ func (p *Processor) NodeInfoGet(ctx context.Context) (*apimodel.Nodeinfo, gtserr
// WebfingerGet handles the GET for a webfinger resource. Most commonly, it will be used for returning account lookups. // WebfingerGet handles the GET for a webfinger resource. Most commonly, it will be used for returning account lookups.
func (p *Processor) WebfingerGet(ctx context.Context, requestedUsername string) (*apimodel.WellKnownResponse, gtserror.WithCode) { func (p *Processor) WebfingerGet(ctx context.Context, requestedUsername string) (*apimodel.WellKnownResponse, gtserror.WithCode) {
// Get the local account the request is referring to. // Get the local account the request is referring to.
requestedAccount, err := p.db.GetAccountByUsernameDomain(ctx, requestedUsername, "") requestedAccount, err := p.state.DB.GetAccountByUsernameDomain(ctx, requestedUsername, "")
if err != nil { if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("database error getting account with username %s: %s", requestedUsername, err)) return nil, gtserror.NewErrorNotFound(fmt.Errorf("database error getting account with username %s: %s", requestedUsername, err))
} }

View file

@ -30,7 +30,7 @@ import (
) )
func (p *Processor) FollowRequestsGet(ctx context.Context, auth *oauth.Auth) ([]apimodel.Account, gtserror.WithCode) { func (p *Processor) FollowRequestsGet(ctx context.Context, auth *oauth.Auth) ([]apimodel.Account, gtserror.WithCode) {
frs, err := p.db.GetAccountFollowRequests(ctx, auth.Account.ID) frs, err := p.state.DB.GetAccountFollowRequests(ctx, auth.Account.ID)
if err != nil { if err != nil {
if err != db.ErrNoEntries { if err != db.ErrNoEntries {
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
@ -40,7 +40,7 @@ func (p *Processor) FollowRequestsGet(ctx context.Context, auth *oauth.Auth) ([]
accts := []apimodel.Account{} accts := []apimodel.Account{}
for _, fr := range frs { for _, fr := range frs {
if fr.Account == nil { if fr.Account == nil {
frAcct, err := p.db.GetAccountByID(ctx, fr.AccountID) frAcct, err := p.state.DB.GetAccountByID(ctx, fr.AccountID)
if err != nil { if err != nil {
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
} }
@ -57,13 +57,13 @@ func (p *Processor) FollowRequestsGet(ctx context.Context, auth *oauth.Auth) ([]
} }
func (p *Processor) FollowRequestAccept(ctx context.Context, auth *oauth.Auth, accountID string) (*apimodel.Relationship, gtserror.WithCode) { func (p *Processor) FollowRequestAccept(ctx context.Context, auth *oauth.Auth, accountID string) (*apimodel.Relationship, gtserror.WithCode) {
follow, err := p.db.AcceptFollowRequest(ctx, accountID, auth.Account.ID) follow, err := p.state.DB.AcceptFollowRequest(ctx, accountID, auth.Account.ID)
if err != nil { if err != nil {
return nil, gtserror.NewErrorNotFound(err) return nil, gtserror.NewErrorNotFound(err)
} }
if follow.Account == nil { if follow.Account == nil {
followAccount, err := p.db.GetAccountByID(ctx, follow.AccountID) followAccount, err := p.state.DB.GetAccountByID(ctx, follow.AccountID)
if err != nil { if err != nil {
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
} }
@ -71,14 +71,14 @@ func (p *Processor) FollowRequestAccept(ctx context.Context, auth *oauth.Auth, a
} }
if follow.TargetAccount == nil { if follow.TargetAccount == nil {
followTargetAccount, err := p.db.GetAccountByID(ctx, follow.TargetAccountID) followTargetAccount, err := p.state.DB.GetAccountByID(ctx, follow.TargetAccountID)
if err != nil { if err != nil {
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
} }
follow.TargetAccount = followTargetAccount follow.TargetAccount = followTargetAccount
} }
p.clientWorker.Queue(messages.FromClientAPI{ p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{
APObjectType: ap.ActivityFollow, APObjectType: ap.ActivityFollow,
APActivityType: ap.ActivityAccept, APActivityType: ap.ActivityAccept,
GTSModel: follow, GTSModel: follow,
@ -86,7 +86,7 @@ func (p *Processor) FollowRequestAccept(ctx context.Context, auth *oauth.Auth, a
TargetAccount: follow.TargetAccount, TargetAccount: follow.TargetAccount,
}) })
gtsR, err := p.db.GetRelationship(ctx, auth.Account.ID, accountID) gtsR, err := p.state.DB.GetRelationship(ctx, auth.Account.ID, accountID)
if err != nil { if err != nil {
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
} }
@ -100,13 +100,13 @@ func (p *Processor) FollowRequestAccept(ctx context.Context, auth *oauth.Auth, a
} }
func (p *Processor) FollowRequestReject(ctx context.Context, auth *oauth.Auth, accountID string) (*apimodel.Relationship, gtserror.WithCode) { func (p *Processor) FollowRequestReject(ctx context.Context, auth *oauth.Auth, accountID string) (*apimodel.Relationship, gtserror.WithCode) {
followRequest, err := p.db.RejectFollowRequest(ctx, accountID, auth.Account.ID) followRequest, err := p.state.DB.RejectFollowRequest(ctx, accountID, auth.Account.ID)
if err != nil { if err != nil {
return nil, gtserror.NewErrorNotFound(err) return nil, gtserror.NewErrorNotFound(err)
} }
if followRequest.Account == nil { if followRequest.Account == nil {
a, err := p.db.GetAccountByID(ctx, followRequest.AccountID) a, err := p.state.DB.GetAccountByID(ctx, followRequest.AccountID)
if err != nil { if err != nil {
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
} }
@ -114,14 +114,14 @@ func (p *Processor) FollowRequestReject(ctx context.Context, auth *oauth.Auth, a
} }
if followRequest.TargetAccount == nil { if followRequest.TargetAccount == nil {
a, err := p.db.GetAccountByID(ctx, followRequest.TargetAccountID) a, err := p.state.DB.GetAccountByID(ctx, followRequest.TargetAccountID)
if err != nil { if err != nil {
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
} }
followRequest.TargetAccount = a followRequest.TargetAccount = a
} }
p.clientWorker.Queue(messages.FromClientAPI{ p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{
APObjectType: ap.ActivityFollow, APObjectType: ap.ActivityFollow,
APActivityType: ap.ActivityReject, APActivityType: ap.ActivityReject,
GTSModel: followRequest, GTSModel: followRequest,
@ -129,7 +129,7 @@ func (p *Processor) FollowRequestReject(ctx context.Context, auth *oauth.Auth, a
TargetAccount: followRequest.TargetAccount, TargetAccount: followRequest.TargetAccount,
}) })
gtsR, err := p.db.GetRelationship(ctx, auth.Account.ID, accountID) gtsR, err := p.state.DB.GetRelationship(ctx, auth.Account.ID, accountID)
if err != nil { if err != nil {
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
} }

View file

@ -143,7 +143,7 @@ func (p *Processor) processCreateAccountFromClientAPI(ctx context.Context, clien
} }
// get the user this account belongs to // get the user this account belongs to
user, err := p.db.GetUserByAccountID(ctx, account.ID) user, err := p.state.DB.GetUserByAccountID(ctx, account.ID)
if err != nil { if err != nil {
return err return err
} }
@ -293,7 +293,7 @@ func (p *Processor) processUndoAnnounceFromClientAPI(ctx context.Context, client
return errors.New("undo was not parseable as *gtsmodel.Status") return errors.New("undo was not parseable as *gtsmodel.Status")
} }
if err := p.db.DeleteStatusByID(ctx, boost.ID); err != nil { if err := p.state.DB.DeleteStatusByID(ctx, boost.ID); err != nil {
return err return err
} }
@ -422,7 +422,7 @@ func (p *Processor) federateStatus(ctx context.Context, status *gtsmodel.Status)
} }
if status.Account == nil { if status.Account == nil {
statusAccount, err := p.db.GetAccountByID(ctx, status.AccountID) statusAccount, err := p.state.DB.GetAccountByID(ctx, status.AccountID)
if err != nil { if err != nil {
return fmt.Errorf("federateStatus: error fetching status author account: %s", err) return fmt.Errorf("federateStatus: error fetching status author account: %s", err)
} }
@ -455,7 +455,7 @@ func (p *Processor) federateStatus(ctx context.Context, status *gtsmodel.Status)
func (p *Processor) federateStatusDelete(ctx context.Context, status *gtsmodel.Status) error { func (p *Processor) federateStatusDelete(ctx context.Context, status *gtsmodel.Status) error {
if status.Account == nil { if status.Account == nil {
statusAccount, err := p.db.GetAccountByID(ctx, status.AccountID) statusAccount, err := p.state.DB.GetAccountByID(ctx, status.AccountID)
if err != nil { if err != nil {
return fmt.Errorf("federateStatusDelete: error fetching status author account: %s", err) return fmt.Errorf("federateStatusDelete: error fetching status author account: %s", err)
} }
@ -642,7 +642,7 @@ func (p *Processor) federateUnannounce(ctx context.Context, boost *gtsmodel.Stat
func (p *Processor) federateAcceptFollowRequest(ctx context.Context, follow *gtsmodel.Follow) error { func (p *Processor) federateAcceptFollowRequest(ctx context.Context, follow *gtsmodel.Follow) error {
if follow.Account == nil { if follow.Account == nil {
a, err := p.db.GetAccountByID(ctx, follow.AccountID) a, err := p.state.DB.GetAccountByID(ctx, follow.AccountID)
if err != nil { if err != nil {
return err return err
} }
@ -651,7 +651,7 @@ func (p *Processor) federateAcceptFollowRequest(ctx context.Context, follow *gts
originAccount := follow.Account originAccount := follow.Account
if follow.TargetAccount == nil { if follow.TargetAccount == nil {
a, err := p.db.GetAccountByID(ctx, follow.TargetAccountID) a, err := p.state.DB.GetAccountByID(ctx, follow.TargetAccountID)
if err != nil { if err != nil {
return err return err
} }
@ -715,7 +715,7 @@ func (p *Processor) federateAcceptFollowRequest(ctx context.Context, follow *gts
func (p *Processor) federateRejectFollowRequest(ctx context.Context, followRequest *gtsmodel.FollowRequest) error { func (p *Processor) federateRejectFollowRequest(ctx context.Context, followRequest *gtsmodel.FollowRequest) error {
if followRequest.Account == nil { if followRequest.Account == nil {
a, err := p.db.GetAccountByID(ctx, followRequest.AccountID) a, err := p.state.DB.GetAccountByID(ctx, followRequest.AccountID)
if err != nil { if err != nil {
return err return err
} }
@ -724,7 +724,7 @@ func (p *Processor) federateRejectFollowRequest(ctx context.Context, followReque
originAccount := followRequest.Account originAccount := followRequest.Account
if followRequest.TargetAccount == nil { if followRequest.TargetAccount == nil {
a, err := p.db.GetAccountByID(ctx, followRequest.TargetAccountID) a, err := p.state.DB.GetAccountByID(ctx, followRequest.TargetAccountID)
if err != nil { if err != nil {
return err return err
} }
@ -844,7 +844,7 @@ func (p *Processor) federateAccountUpdate(ctx context.Context, updatedAccount *g
func (p *Processor) federateBlock(ctx context.Context, block *gtsmodel.Block) error { func (p *Processor) federateBlock(ctx context.Context, block *gtsmodel.Block) error {
if block.Account == nil { if block.Account == nil {
blockAccount, err := p.db.GetAccountByID(ctx, block.AccountID) blockAccount, err := p.state.DB.GetAccountByID(ctx, block.AccountID)
if err != nil { if err != nil {
return fmt.Errorf("federateBlock: error getting block account from database: %s", err) return fmt.Errorf("federateBlock: error getting block account from database: %s", err)
} }
@ -852,7 +852,7 @@ func (p *Processor) federateBlock(ctx context.Context, block *gtsmodel.Block) er
} }
if block.TargetAccount == nil { if block.TargetAccount == nil {
blockTargetAccount, err := p.db.GetAccountByID(ctx, block.TargetAccountID) blockTargetAccount, err := p.state.DB.GetAccountByID(ctx, block.TargetAccountID)
if err != nil { if err != nil {
return fmt.Errorf("federateBlock: error getting block target account from database: %s", err) return fmt.Errorf("federateBlock: error getting block target account from database: %s", err)
} }
@ -880,7 +880,7 @@ func (p *Processor) federateBlock(ctx context.Context, block *gtsmodel.Block) er
func (p *Processor) federateUnblock(ctx context.Context, block *gtsmodel.Block) error { func (p *Processor) federateUnblock(ctx context.Context, block *gtsmodel.Block) error {
if block.Account == nil { if block.Account == nil {
blockAccount, err := p.db.GetAccountByID(ctx, block.AccountID) blockAccount, err := p.state.DB.GetAccountByID(ctx, block.AccountID)
if err != nil { if err != nil {
return fmt.Errorf("federateUnblock: error getting block account from database: %s", err) return fmt.Errorf("federateUnblock: error getting block account from database: %s", err)
} }
@ -888,7 +888,7 @@ func (p *Processor) federateUnblock(ctx context.Context, block *gtsmodel.Block)
} }
if block.TargetAccount == nil { if block.TargetAccount == nil {
blockTargetAccount, err := p.db.GetAccountByID(ctx, block.TargetAccountID) blockTargetAccount, err := p.state.DB.GetAccountByID(ctx, block.TargetAccountID)
if err != nil { if err != nil {
return fmt.Errorf("federateUnblock: error getting block target account from database: %s", err) return fmt.Errorf("federateUnblock: error getting block target account from database: %s", err)
} }
@ -934,7 +934,7 @@ func (p *Processor) federateUnblock(ctx context.Context, block *gtsmodel.Block)
func (p *Processor) federateReport(ctx context.Context, report *gtsmodel.Report) error { func (p *Processor) federateReport(ctx context.Context, report *gtsmodel.Report) error {
if report.TargetAccount == nil { if report.TargetAccount == nil {
reportTargetAccount, err := p.db.GetAccountByID(ctx, report.TargetAccountID) reportTargetAccount, err := p.state.DB.GetAccountByID(ctx, report.TargetAccountID)
if err != nil { if err != nil {
return fmt.Errorf("federateReport: error getting report target account from database: %w", err) return fmt.Errorf("federateReport: error getting report target account from database: %w", err)
} }
@ -942,7 +942,7 @@ func (p *Processor) federateReport(ctx context.Context, report *gtsmodel.Report)
} }
if len(report.StatusIDs) > 0 && len(report.Statuses) == 0 { if len(report.StatusIDs) > 0 && len(report.Statuses) == 0 {
statuses, err := p.db.GetStatuses(ctx, report.StatusIDs) statuses, err := p.state.DB.GetStatuses(ctx, report.StatusIDs)
if err != nil { if err != nil {
return fmt.Errorf("federateReport: error getting report statuses from database: %w", err) return fmt.Errorf("federateReport: error getting report statuses from database: %w", err)
} }
@ -966,7 +966,7 @@ func (p *Processor) federateReport(ctx context.Context, report *gtsmodel.Report)
// deliver the flag using the outbox of the // deliver the flag using the outbox of the
// instance account to anonymize the report // instance account to anonymize the report
instanceAccount, err := p.db.GetInstanceAccount(ctx, "") instanceAccount, err := p.state.DB.GetInstanceAccount(ctx, "")
if err != nil { if err != nil {
return fmt.Errorf("federateReport: error getting instance account: %w", err) return fmt.Errorf("federateReport: error getting instance account: %w", err)
} }

View file

@ -38,7 +38,7 @@ func (p *Processor) notifyStatus(ctx context.Context, status *gtsmodel.Status) e
if status.Mentions == nil { if status.Mentions == nil {
// there are mentions but they're not fully populated on the status yet so do this // there are mentions but they're not fully populated on the status yet so do this
menchies, err := p.db.GetMentions(ctx, status.MentionIDs) menchies, err := p.state.DB.GetMentions(ctx, status.MentionIDs)
if err != nil { if err != nil {
return fmt.Errorf("notifyStatus: error getting mentions for status %s from the db: %s", status.ID, err) return fmt.Errorf("notifyStatus: error getting mentions for status %s from the db: %s", status.ID, err)
} }
@ -49,7 +49,7 @@ func (p *Processor) notifyStatus(ctx context.Context, status *gtsmodel.Status) e
for _, m := range status.Mentions { for _, m := range status.Mentions {
// make sure this is a local account, otherwise we don't need to create a notification for it // make sure this is a local account, otherwise we don't need to create a notification for it
if m.TargetAccount == nil { if m.TargetAccount == nil {
a, err := p.db.GetAccountByID(ctx, m.TargetAccountID) a, err := p.state.DB.GetAccountByID(ctx, m.TargetAccountID)
if err != nil { if err != nil {
// we don't have the account or there's been an error // we don't have the account or there's been an error
return fmt.Errorf("notifyStatus: error getting account with id %s from the db: %s", m.TargetAccountID, err) return fmt.Errorf("notifyStatus: error getting account with id %s from the db: %s", m.TargetAccountID, err)
@ -62,7 +62,7 @@ func (p *Processor) notifyStatus(ctx context.Context, status *gtsmodel.Status) e
} }
// make sure a notif doesn't already exist for this mention // make sure a notif doesn't already exist for this mention
if err := p.db.GetWhere(ctx, []db.Where{ if err := p.state.DB.GetWhere(ctx, []db.Where{
{Key: "notification_type", Value: gtsmodel.NotificationMention}, {Key: "notification_type", Value: gtsmodel.NotificationMention},
{Key: "target_account_id", Value: m.TargetAccountID}, {Key: "target_account_id", Value: m.TargetAccountID},
{Key: "origin_account_id", Value: m.OriginAccountID}, {Key: "origin_account_id", Value: m.OriginAccountID},
@ -87,7 +87,7 @@ func (p *Processor) notifyStatus(ctx context.Context, status *gtsmodel.Status) e
Status: status, Status: status,
} }
if err := p.db.Put(ctx, notif); err != nil { if err := p.state.DB.Put(ctx, notif); err != nil {
return fmt.Errorf("notifyStatus: error putting notification in database: %s", err) return fmt.Errorf("notifyStatus: error putting notification in database: %s", err)
} }
@ -108,7 +108,7 @@ func (p *Processor) notifyStatus(ctx context.Context, status *gtsmodel.Status) e
func (p *Processor) notifyFollowRequest(ctx context.Context, followRequest *gtsmodel.FollowRequest) error { func (p *Processor) notifyFollowRequest(ctx context.Context, followRequest *gtsmodel.FollowRequest) error {
// make sure we have the target account pinned on the follow request // make sure we have the target account pinned on the follow request
if followRequest.TargetAccount == nil { if followRequest.TargetAccount == nil {
a, err := p.db.GetAccountByID(ctx, followRequest.TargetAccountID) a, err := p.state.DB.GetAccountByID(ctx, followRequest.TargetAccountID)
if err != nil { if err != nil {
return err return err
} }
@ -129,7 +129,7 @@ func (p *Processor) notifyFollowRequest(ctx context.Context, followRequest *gtsm
OriginAccountID: followRequest.AccountID, OriginAccountID: followRequest.AccountID,
} }
if err := p.db.Put(ctx, notif); err != nil { if err := p.state.DB.Put(ctx, notif); err != nil {
return fmt.Errorf("notifyFollowRequest: error putting notification in database: %s", err) return fmt.Errorf("notifyFollowRequest: error putting notification in database: %s", err)
} }
@ -153,7 +153,7 @@ func (p *Processor) notifyFollow(ctx context.Context, follow *gtsmodel.Follow, t
} }
// first remove the follow request notification // first remove the follow request notification
if err := p.db.DeleteWhere(ctx, []db.Where{ if err := p.state.DB.DeleteWhere(ctx, []db.Where{
{Key: "notification_type", Value: gtsmodel.NotificationFollowRequest}, {Key: "notification_type", Value: gtsmodel.NotificationFollowRequest},
{Key: "target_account_id", Value: follow.TargetAccountID}, {Key: "target_account_id", Value: follow.TargetAccountID},
{Key: "origin_account_id", Value: follow.AccountID}, {Key: "origin_account_id", Value: follow.AccountID},
@ -170,7 +170,7 @@ func (p *Processor) notifyFollow(ctx context.Context, follow *gtsmodel.Follow, t
OriginAccountID: follow.AccountID, OriginAccountID: follow.AccountID,
OriginAccount: follow.Account, OriginAccount: follow.Account,
} }
if err := p.db.Put(ctx, notif); err != nil { if err := p.state.DB.Put(ctx, notif); err != nil {
return fmt.Errorf("notifyFollow: error putting notification in database: %s", err) return fmt.Errorf("notifyFollow: error putting notification in database: %s", err)
} }
@ -194,7 +194,7 @@ func (p *Processor) notifyFave(ctx context.Context, fave *gtsmodel.StatusFave) e
} }
if fave.TargetAccount == nil { if fave.TargetAccount == nil {
a, err := p.db.GetAccountByID(ctx, fave.TargetAccountID) a, err := p.state.DB.GetAccountByID(ctx, fave.TargetAccountID)
if err != nil { if err != nil {
return err return err
} }
@ -218,7 +218,7 @@ func (p *Processor) notifyFave(ctx context.Context, fave *gtsmodel.StatusFave) e
Status: fave.Status, Status: fave.Status,
} }
if err := p.db.Put(ctx, notif); err != nil { if err := p.state.DB.Put(ctx, notif); err != nil {
return fmt.Errorf("notifyFave: error putting notification in database: %s", err) return fmt.Errorf("notifyFave: error putting notification in database: %s", err)
} }
@ -242,7 +242,7 @@ func (p *Processor) notifyAnnounce(ctx context.Context, status *gtsmodel.Status)
} }
if status.BoostOf == nil { if status.BoostOf == nil {
boostedStatus, err := p.db.GetStatusByID(ctx, status.BoostOfID) boostedStatus, err := p.state.DB.GetStatusByID(ctx, status.BoostOfID)
if err != nil { if err != nil {
return fmt.Errorf("notifyAnnounce: error getting status with id %s: %s", status.BoostOfID, err) return fmt.Errorf("notifyAnnounce: error getting status with id %s: %s", status.BoostOfID, err)
} }
@ -250,7 +250,7 @@ func (p *Processor) notifyAnnounce(ctx context.Context, status *gtsmodel.Status)
} }
if status.BoostOfAccount == nil { if status.BoostOfAccount == nil {
boostedAcct, err := p.db.GetAccountByID(ctx, status.BoostOfAccountID) boostedAcct, err := p.state.DB.GetAccountByID(ctx, status.BoostOfAccountID)
if err != nil { if err != nil {
return fmt.Errorf("notifyAnnounce: error getting account with id %s: %s", status.BoostOfAccountID, err) return fmt.Errorf("notifyAnnounce: error getting account with id %s: %s", status.BoostOfAccountID, err)
} }
@ -269,7 +269,7 @@ func (p *Processor) notifyAnnounce(ctx context.Context, status *gtsmodel.Status)
} }
// make sure a notif doesn't already exist for this announce // make sure a notif doesn't already exist for this announce
err := p.db.GetWhere(ctx, []db.Where{ err := p.state.DB.GetWhere(ctx, []db.Where{
{Key: "notification_type", Value: gtsmodel.NotificationReblog}, {Key: "notification_type", Value: gtsmodel.NotificationReblog},
{Key: "target_account_id", Value: status.BoostOfAccountID}, {Key: "target_account_id", Value: status.BoostOfAccountID},
{Key: "origin_account_id", Value: status.AccountID}, {Key: "origin_account_id", Value: status.AccountID},
@ -292,7 +292,7 @@ func (p *Processor) notifyAnnounce(ctx context.Context, status *gtsmodel.Status)
Status: status, Status: status,
} }
if err := p.db.Put(ctx, notif); err != nil { if err := p.state.DB.Put(ctx, notif); err != nil {
return fmt.Errorf("notifyAnnounce: error putting notification in database: %s", err) return fmt.Errorf("notifyAnnounce: error putting notification in database: %s", err)
} }
@ -314,7 +314,7 @@ func (p *Processor) notifyAnnounce(ctx context.Context, status *gtsmodel.Status)
func (p *Processor) timelineStatus(ctx context.Context, status *gtsmodel.Status) error { func (p *Processor) timelineStatus(ctx context.Context, status *gtsmodel.Status) error {
// make sure the author account is pinned onto the status // make sure the author account is pinned onto the status
if status.Account == nil { if status.Account == nil {
a, err := p.db.GetAccountByID(ctx, status.AccountID) a, err := p.state.DB.GetAccountByID(ctx, status.AccountID)
if err != nil { if err != nil {
return fmt.Errorf("timelineStatus: error getting author account with id %s: %s", status.AccountID, err) return fmt.Errorf("timelineStatus: error getting author account with id %s: %s", status.AccountID, err)
} }
@ -322,7 +322,7 @@ func (p *Processor) timelineStatus(ctx context.Context, status *gtsmodel.Status)
} }
// get local followers of the account that posted the status // get local followers of the account that posted the status
follows, err := p.db.GetAccountFollowedBy(ctx, status.AccountID, true) follows, err := p.state.DB.GetAccountFollowedBy(ctx, status.AccountID, true)
if err != nil { if err != nil {
return fmt.Errorf("timelineStatus: error getting followers for account id %s: %s", status.AccountID, err) return fmt.Errorf("timelineStatus: error getting followers for account id %s: %s", status.AccountID, err)
} }
@ -374,7 +374,7 @@ func (p *Processor) timelineStatusForAccount(ctx context.Context, status *gtsmod
defer wg.Done() defer wg.Done()
// get the timeline owner account // get the timeline owner account
timelineAccount, err := p.db.GetAccountByID(ctx, accountID) timelineAccount, err := p.state.DB.GetAccountByID(ctx, accountID)
if err != nil { if err != nil {
errors <- fmt.Errorf("timelineStatusForAccount: error getting account for timeline with id %s: %s", accountID, err) errors <- fmt.Errorf("timelineStatusForAccount: error getting account for timeline with id %s: %s", accountID, err)
return return
@ -446,28 +446,28 @@ func (p *Processor) wipeStatus(ctx context.Context, statusToDelete *gtsmodel.Sta
// delete all mention entries generated by this status // delete all mention entries generated by this status
for _, m := range statusToDelete.MentionIDs { for _, m := range statusToDelete.MentionIDs {
if err := p.db.DeleteByID(ctx, m, &gtsmodel.Mention{}); err != nil { if err := p.state.DB.DeleteByID(ctx, m, &gtsmodel.Mention{}); err != nil {
return err return err
} }
} }
// delete all notification entries generated by this status // delete all notification entries generated by this status
if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "status_id", Value: statusToDelete.ID}}, &[]*gtsmodel.Notification{}); err != nil { if err := p.state.DB.DeleteWhere(ctx, []db.Where{{Key: "status_id", Value: statusToDelete.ID}}, &[]*gtsmodel.Notification{}); err != nil {
return err return err
} }
// delete all bookmarks that point to this status // delete all bookmarks that point to this status
if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "status_id", Value: statusToDelete.ID}}, &[]*gtsmodel.StatusBookmark{}); err != nil { if err := p.state.DB.DeleteWhere(ctx, []db.Where{{Key: "status_id", Value: statusToDelete.ID}}, &[]*gtsmodel.StatusBookmark{}); err != nil {
return err return err
} }
// delete all boosts for this status + remove them from timelines // delete all boosts for this status + remove them from timelines
if boosts, err := p.db.GetStatusReblogs(ctx, statusToDelete); err == nil { if boosts, err := p.state.DB.GetStatusReblogs(ctx, statusToDelete); err == nil {
for _, b := range boosts { for _, b := range boosts {
if err := p.deleteStatusFromTimelines(ctx, b); err != nil { if err := p.deleteStatusFromTimelines(ctx, b); err != nil {
return err return err
} }
if err := p.db.DeleteStatusByID(ctx, b.ID); err != nil { if err := p.state.DB.DeleteStatusByID(ctx, b.ID); err != nil {
return err return err
} }
} }
@ -479,7 +479,7 @@ func (p *Processor) wipeStatus(ctx context.Context, statusToDelete *gtsmodel.Sta
} }
// delete the status itself // delete the status itself
if err := p.db.DeleteStatusByID(ctx, statusToDelete.ID); err != nil { if err := p.state.DB.DeleteStatusByID(ctx, statusToDelete.ID); err != nil {
return err return err
} }

View file

@ -139,7 +139,7 @@ func (p *Processor) processCreateStatusFromFederator(ctx context.Context, federa
// make sure the account is pinned // make sure the account is pinned
if status.Account == nil { if status.Account == nil {
a, err := p.db.GetAccountByID(ctx, status.AccountID) a, err := p.state.DB.GetAccountByID(ctx, status.AccountID)
if err != nil { if err != nil {
return err return err
} }
@ -185,7 +185,7 @@ func (p *Processor) processCreateFaveFromFederator(ctx context.Context, federato
// make sure the account is pinned // make sure the account is pinned
if incomingFave.Account == nil { if incomingFave.Account == nil {
a, err := p.db.GetAccountByID(ctx, incomingFave.AccountID) a, err := p.state.DB.GetAccountByID(ctx, incomingFave.AccountID)
if err != nil { if err != nil {
return err return err
} }
@ -227,7 +227,7 @@ func (p *Processor) processCreateFollowRequestFromFederator(ctx context.Context,
// make sure the account is pinned // make sure the account is pinned
if followRequest.Account == nil { if followRequest.Account == nil {
a, err := p.db.GetAccountByID(ctx, followRequest.AccountID) a, err := p.state.DB.GetAccountByID(ctx, followRequest.AccountID)
if err != nil { if err != nil {
return err return err
} }
@ -254,7 +254,7 @@ func (p *Processor) processCreateFollowRequestFromFederator(ctx context.Context,
} }
if followRequest.TargetAccount == nil { if followRequest.TargetAccount == nil {
a, err := p.db.GetAccountByID(ctx, followRequest.TargetAccountID) a, err := p.state.DB.GetAccountByID(ctx, followRequest.TargetAccountID)
if err != nil { if err != nil {
return err return err
} }
@ -267,7 +267,7 @@ func (p *Processor) processCreateFollowRequestFromFederator(ctx context.Context,
} }
// if the target account isn't locked, we should already accept the follow and notify about the new follower instead // if the target account isn't locked, we should already accept the follow and notify about the new follower instead
follow, err := p.db.AcceptFollowRequest(ctx, followRequest.AccountID, followRequest.TargetAccountID) follow, err := p.state.DB.AcceptFollowRequest(ctx, followRequest.AccountID, followRequest.TargetAccountID)
if err != nil { if err != nil {
return err return err
} }
@ -288,7 +288,7 @@ func (p *Processor) processCreateAnnounceFromFederator(ctx context.Context, fede
// make sure the account is pinned // make sure the account is pinned
if incomingAnnounce.Account == nil { if incomingAnnounce.Account == nil {
a, err := p.db.GetAccountByID(ctx, incomingAnnounce.AccountID) a, err := p.state.DB.GetAccountByID(ctx, incomingAnnounce.AccountID)
if err != nil { if err != nil {
return err return err
} }
@ -324,7 +324,7 @@ func (p *Processor) processCreateAnnounceFromFederator(ctx context.Context, fede
} }
incomingAnnounce.ID = incomingAnnounceID incomingAnnounce.ID = incomingAnnounceID
if err := p.db.PutStatus(ctx, incomingAnnounce); err != nil { if err := p.state.DB.PutStatus(ctx, incomingAnnounce); err != nil {
return fmt.Errorf("error adding dereferenced announce to the db: %s", err) return fmt.Errorf("error adding dereferenced announce to the db: %s", err)
} }

View file

@ -344,7 +344,6 @@ func (suite *FromFederatorTestSuite) TestProcessAccountDelete() {
suite.NoError(err) suite.NoError(err)
// now they are mufos! // now they are mufos!
err = suite.processor.ProcessFromFederator(ctx, messages.FromFederator{ err = suite.processor.ProcessFromFederator(ctx, messages.FromFederator{
APObjectType: ap.ObjectProfile, APObjectType: ap.ObjectProfile,
APActivityType: ap.ActivityDelete, APActivityType: ap.ActivityDelete,

View file

@ -35,7 +35,7 @@ import (
func (p *Processor) getThisInstance(ctx context.Context) (*gtsmodel.Instance, error) { func (p *Processor) getThisInstance(ctx context.Context) (*gtsmodel.Instance, error) {
i := &gtsmodel.Instance{} i := &gtsmodel.Instance{}
if err := p.db.GetWhere(ctx, []db.Where{{Key: "domain", Value: config.GetHost()}}, i); err != nil { if err := p.state.DB.GetWhere(ctx, []db.Where{{Key: "domain", Value: config.GetHost()}}, i); err != nil {
return nil, err return nil, err
} }
return i, nil return i, nil
@ -73,7 +73,7 @@ func (p *Processor) InstancePeersGet(ctx context.Context, includeSuspended bool,
domains := []*apimodel.Domain{} domains := []*apimodel.Domain{}
if includeOpen { if includeOpen {
instances, err := p.db.GetInstancePeers(ctx, false) instances, err := p.state.DB.GetInstancePeers(ctx, false)
if err != nil && err != db.ErrNoEntries { if err != nil && err != db.ErrNoEntries {
err = fmt.Errorf("error selecting instance peers: %s", err) err = fmt.Errorf("error selecting instance peers: %s", err)
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
@ -87,7 +87,7 @@ func (p *Processor) InstancePeersGet(ctx context.Context, includeSuspended bool,
if includeSuspended { if includeSuspended {
domainBlocks := []*gtsmodel.DomainBlock{} domainBlocks := []*gtsmodel.DomainBlock{}
if err := p.db.GetAll(ctx, &domainBlocks); err != nil && err != db.ErrNoEntries { if err := p.state.DB.GetAll(ctx, &domainBlocks); err != nil && err != db.ErrNoEntries {
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
} }
@ -124,12 +124,12 @@ func (p *Processor) InstancePatch(ctx context.Context, form *apimodel.InstanceSe
// fetch the instance entry from the db for processing // fetch the instance entry from the db for processing
i := &gtsmodel.Instance{} i := &gtsmodel.Instance{}
host := config.GetHost() host := config.GetHost()
if err := p.db.GetWhere(ctx, []db.Where{{Key: "domain", Value: host}}, i); err != nil { if err := p.state.DB.GetWhere(ctx, []db.Where{{Key: "domain", Value: host}}, i); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error fetching instance %s: %s", host, err)) return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error fetching instance %s: %s", host, err))
} }
// fetch the instance account from the db for processing // fetch the instance account from the db for processing
ia, err := p.db.GetInstanceAccount(ctx, "") ia, err := p.state.DB.GetInstanceAccount(ctx, "")
if err != nil { if err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error fetching instance account %s: %s", host, err)) return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error fetching instance account %s: %s", host, err))
} }
@ -148,12 +148,12 @@ func (p *Processor) InstancePatch(ctx context.Context, form *apimodel.InstanceSe
// validate & update site contact account if it's set on the form // validate & update site contact account if it's set on the form
if form.ContactUsername != nil { if form.ContactUsername != nil {
// make sure the account with the given username exists in the db // make sure the account with the given username exists in the db
contactAccount, err := p.db.GetAccountByUsernameDomain(ctx, *form.ContactUsername, "") contactAccount, err := p.state.DB.GetAccountByUsernameDomain(ctx, *form.ContactUsername, "")
if err != nil { if err != nil {
return nil, gtserror.NewErrorBadRequest(err, fmt.Sprintf("account with username %s not retrievable", *form.ContactUsername)) return nil, gtserror.NewErrorBadRequest(err, fmt.Sprintf("account with username %s not retrievable", *form.ContactUsername))
} }
// make sure it has a user associated with it // make sure it has a user associated with it
contactUser, err := p.db.GetUserByAccountID(ctx, contactAccount.ID) contactUser, err := p.state.DB.GetUserByAccountID(ctx, contactAccount.ID)
if err != nil { if err != nil {
return nil, gtserror.NewErrorBadRequest(err, fmt.Sprintf("user for account with username %s not retrievable", *form.ContactUsername)) return nil, gtserror.NewErrorBadRequest(err, fmt.Sprintf("user for account with username %s not retrievable", *form.ContactUsername))
} }
@ -233,7 +233,7 @@ func (p *Processor) InstancePatch(ctx context.Context, form *apimodel.InstanceSe
} else if form.AvatarDescription != nil && ia.AvatarMediaAttachment != nil { } else if form.AvatarDescription != nil && ia.AvatarMediaAttachment != nil {
// process just the description for the existing avatar // process just the description for the existing avatar
ia.AvatarMediaAttachment.Description = *form.AvatarDescription ia.AvatarMediaAttachment.Description = *form.AvatarDescription
if err := p.db.UpdateByID(ctx, ia.AvatarMediaAttachment, ia.AvatarMediaAttachmentID, "description"); err != nil { if err := p.state.DB.UpdateByID(ctx, ia.AvatarMediaAttachment, ia.AvatarMediaAttachmentID, "description"); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error updating instance avatar description: %s", err)) return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error updating instance avatar description: %s", err))
} }
} }
@ -252,13 +252,13 @@ func (p *Processor) InstancePatch(ctx context.Context, form *apimodel.InstanceSe
if updateInstanceAccount { if updateInstanceAccount {
// if either avatar or header is updated, we need // if either avatar or header is updated, we need
// to update the instance account that stores them // to update the instance account that stores them
if err := p.db.UpdateAccount(ctx, ia); err != nil { if err := p.state.DB.UpdateAccount(ctx, ia); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error updating instance account: %s", err)) return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error updating instance account: %s", err))
} }
} }
if len(updatingColumns) != 0 { if len(updatingColumns) != 0 {
if err := p.db.UpdateByID(ctx, i, i.ID, updatingColumns...); err != nil { if err := p.state.DB.UpdateByID(ctx, i, i.ID, updatingColumns...); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error updating instance %s: %s", host, err)) return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error updating instance %s: %s", host, err))
} }
} }

View file

@ -13,7 +13,7 @@ import (
// Delete deletes the media attachment with the given ID, including all files pertaining to that attachment. // Delete deletes the media attachment with the given ID, including all files pertaining to that attachment.
func (p *Processor) Delete(ctx context.Context, mediaAttachmentID string) gtserror.WithCode { func (p *Processor) Delete(ctx context.Context, mediaAttachmentID string) gtserror.WithCode {
attachment, err := p.db.GetAttachmentByID(ctx, mediaAttachmentID) attachment, err := p.state.DB.GetAttachmentByID(ctx, mediaAttachmentID)
if err != nil { if err != nil {
if err == db.ErrNoEntries { if err == db.ErrNoEntries {
// attachment already gone // attachment already gone
@ -27,20 +27,20 @@ func (p *Processor) Delete(ctx context.Context, mediaAttachmentID string) gtserr
// delete the thumbnail from storage // delete the thumbnail from storage
if attachment.Thumbnail.Path != "" { if attachment.Thumbnail.Path != "" {
if err := p.storage.Delete(ctx, attachment.Thumbnail.Path); err != nil && !errors.Is(err, storage.ErrNotFound) { if err := p.state.Storage.Delete(ctx, attachment.Thumbnail.Path); err != nil && !errors.Is(err, storage.ErrNotFound) {
errs = append(errs, fmt.Sprintf("remove thumbnail at path %s: %s", attachment.Thumbnail.Path, err)) errs = append(errs, fmt.Sprintf("remove thumbnail at path %s: %s", attachment.Thumbnail.Path, err))
} }
} }
// delete the file from storage // delete the file from storage
if attachment.File.Path != "" { if attachment.File.Path != "" {
if err := p.storage.Delete(ctx, attachment.File.Path); err != nil && !errors.Is(err, storage.ErrNotFound) { if err := p.state.Storage.Delete(ctx, attachment.File.Path); err != nil && !errors.Is(err, storage.ErrNotFound) {
errs = append(errs, fmt.Sprintf("remove file at path %s: %s", attachment.File.Path, err)) errs = append(errs, fmt.Sprintf("remove file at path %s: %s", attachment.File.Path, err))
} }
} }
// delete the attachment // delete the attachment
if err := p.db.DeleteByID(ctx, mediaAttachmentID, attachment); err != nil && !errors.Is(err, db.ErrNoEntries) { if err := p.state.DB.DeleteByID(ctx, mediaAttachmentID, attachment); err != nil && !errors.Is(err, db.ErrNoEntries) {
errs = append(errs, fmt.Sprintf("remove attachment: %s", err)) errs = append(errs, fmt.Sprintf("remove attachment: %s", err))
} }

View file

@ -31,7 +31,7 @@ import (
// GetCustomEmojis returns a list of all useable local custom emojis stored on this instance. // GetCustomEmojis returns a list of all useable local custom emojis stored on this instance.
// 'useable' in this context means visible and picker, and not disabled. // 'useable' in this context means visible and picker, and not disabled.
func (p *Processor) GetCustomEmojis(ctx context.Context) ([]*apimodel.Emoji, gtserror.WithCode) { func (p *Processor) GetCustomEmojis(ctx context.Context) ([]*apimodel.Emoji, gtserror.WithCode) {
emojis, err := p.db.GetUseableEmojis(ctx) emojis, err := p.state.DB.GetUseableEmojis(ctx)
if err != nil { if err != nil {
if err != db.ErrNoEntries { if err != db.ErrNoEntries {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("db error retrieving custom emojis: %s", err)) return nil, gtserror.NewErrorNotFound(fmt.Errorf("db error retrieving custom emojis: %s", err))

View file

@ -54,7 +54,7 @@ func (p *Processor) GetFile(ctx context.Context, requestingAccount *gtsmodel.Acc
owningAccountID := form.AccountID owningAccountID := form.AccountID
// get the account that owns the media and make sure it's not suspended // get the account that owns the media and make sure it's not suspended
owningAccount, err := p.db.GetAccountByID(ctx, owningAccountID) owningAccount, err := p.state.DB.GetAccountByID(ctx, owningAccountID)
if err != nil { if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("account with id %s could not be selected from the db: %s", owningAccountID, err)) return nil, gtserror.NewErrorNotFound(fmt.Errorf("account with id %s could not be selected from the db: %s", owningAccountID, err))
} }
@ -64,7 +64,7 @@ func (p *Processor) GetFile(ctx context.Context, requestingAccount *gtsmodel.Acc
// make sure the requesting account and the media account don't block each other // make sure the requesting account and the media account don't block each other
if requestingAccount != nil { if requestingAccount != nil {
blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, owningAccountID, true) blocked, err := p.state.DB.IsBlocked(ctx, requestingAccount.ID, owningAccountID, true)
if err != nil { if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("block status could not be established between accounts %s and %s: %s", owningAccountID, requestingAccount.ID, err)) return nil, gtserror.NewErrorNotFound(fmt.Errorf("block status could not be established between accounts %s and %s: %s", owningAccountID, requestingAccount.ID, err))
} }
@ -117,7 +117,7 @@ func parseSize(s string) (media.Size, error) {
func (p *Processor) getAttachmentContent(ctx context.Context, requestingAccount *gtsmodel.Account, wantedMediaID string, owningAccountID string, mediaSize media.Size) (*apimodel.Content, gtserror.WithCode) { func (p *Processor) getAttachmentContent(ctx context.Context, requestingAccount *gtsmodel.Account, wantedMediaID string, owningAccountID string, mediaSize media.Size) (*apimodel.Content, gtserror.WithCode) {
// retrieve attachment from the database and do basic checks on it // retrieve attachment from the database and do basic checks on it
a, err := p.db.GetAttachmentByID(ctx, wantedMediaID) a, err := p.state.DB.GetAttachmentByID(ctx, wantedMediaID)
if err != nil { if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("attachment %s could not be taken from the db: %s", wantedMediaID, err)) return nil, gtserror.NewErrorNotFound(fmt.Errorf("attachment %s could not be taken from the db: %s", wantedMediaID, err))
} }
@ -209,7 +209,7 @@ func (p *Processor) getEmojiContent(ctx context.Context, fileName string, owning
// so this is more reliable than using full size url // so this is more reliable than using full size url
imageStaticURL := uris.GenerateURIForAttachment(owningAccountID, string(media.TypeEmoji), string(media.SizeStatic), fileName, "png") imageStaticURL := uris.GenerateURIForAttachment(owningAccountID, string(media.TypeEmoji), string(media.SizeStatic), fileName, "png")
e, err := p.db.GetEmojiByStaticURL(ctx, imageStaticURL) e, err := p.state.DB.GetEmojiByStaticURL(ctx, imageStaticURL)
if err != nil { if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("emoji %s could not be taken from the db: %s", fileName, err)) return nil, gtserror.NewErrorNotFound(fmt.Errorf("emoji %s could not be taken from the db: %s", fileName, err))
} }
@ -237,12 +237,12 @@ func (p *Processor) getEmojiContent(ctx context.Context, fileName string, owning
func (p *Processor) retrieveFromStorage(ctx context.Context, storagePath string, content *apimodel.Content) (*apimodel.Content, gtserror.WithCode) { func (p *Processor) retrieveFromStorage(ctx context.Context, storagePath string, content *apimodel.Content) (*apimodel.Content, gtserror.WithCode) {
// If running on S3 storage with proxying disabled then // If running on S3 storage with proxying disabled then
// just fetch a pre-signed URL instead of serving the content. // just fetch a pre-signed URL instead of serving the content.
if url := p.storage.URL(ctx, storagePath); url != nil { if url := p.state.Storage.URL(ctx, storagePath); url != nil {
content.URL = url content.URL = url
return content, nil return content, nil
} }
reader, err := p.storage.GetStream(ctx, storagePath) reader, err := p.state.Storage.GetStream(ctx, storagePath)
if err != nil { if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("error retrieving from storage: %s", err)) return nil, gtserror.NewErrorNotFound(fmt.Errorf("error retrieving from storage: %s", err))
} }

View file

@ -30,7 +30,7 @@ import (
) )
func (p *Processor) Get(ctx context.Context, account *gtsmodel.Account, mediaAttachmentID string) (*apimodel.Attachment, gtserror.WithCode) { func (p *Processor) Get(ctx context.Context, account *gtsmodel.Account, mediaAttachmentID string) (*apimodel.Attachment, gtserror.WithCode) {
attachment, err := p.db.GetAttachmentByID(ctx, mediaAttachmentID) attachment, err := p.state.DB.GetAttachmentByID(ctx, mediaAttachmentID)
if err != nil { if err != nil {
if err == db.ErrNoEntries { if err == db.ErrNoEntries {
// attachment doesn't exist // attachment doesn't exist

View file

@ -19,28 +19,25 @@
package media package media
import ( import (
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/media" "github.com/superseriousbusiness/gotosocial/internal/media"
"github.com/superseriousbusiness/gotosocial/internal/storage" "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/transport" "github.com/superseriousbusiness/gotosocial/internal/transport"
"github.com/superseriousbusiness/gotosocial/internal/typeutils" "github.com/superseriousbusiness/gotosocial/internal/typeutils"
) )
type Processor struct { type Processor struct {
state *state.State
tc typeutils.TypeConverter tc typeutils.TypeConverter
mediaManager media.Manager mediaManager media.Manager
transportController transport.Controller transportController transport.Controller
storage *storage.Driver
db db.DB
} }
// New returns a new media processor. // New returns a new media processor.
func New(db db.DB, tc typeutils.TypeConverter, mediaManager media.Manager, transportController transport.Controller, storage *storage.Driver) Processor { func New(state *state.State, tc typeutils.TypeConverter, mediaManager media.Manager, transportController transport.Controller) Processor {
return Processor{ return Processor{
state: state,
tc: tc, tc: tc,
mediaManager: mediaManager, mediaManager: mediaManager,
transportController: transportController, transportController: transportController,
storage: storage,
db: db,
} }
} }

View file

@ -20,12 +20,11 @@ package media_test
import ( import (
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/media" "github.com/superseriousbusiness/gotosocial/internal/media"
"github.com/superseriousbusiness/gotosocial/internal/messages"
mediaprocessing "github.com/superseriousbusiness/gotosocial/internal/processing/media" mediaprocessing "github.com/superseriousbusiness/gotosocial/internal/processing/media"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage" "github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/transport" "github.com/superseriousbusiness/gotosocial/internal/transport"
"github.com/superseriousbusiness/gotosocial/internal/typeutils" "github.com/superseriousbusiness/gotosocial/internal/typeutils"
@ -38,6 +37,7 @@ type MediaStandardTestSuite struct {
db db.DB db db.DB
tc typeutils.TypeConverter tc typeutils.TypeConverter
storage *storage.Driver storage *storage.Driver
state state.State
mediaManager media.Manager mediaManager media.Manager
transportController transport.Controller transportController transport.Controller
@ -67,15 +67,19 @@ func (suite *MediaStandardTestSuite) SetupSuite() {
} }
func (suite *MediaStandardTestSuite) SetupTest() { func (suite *MediaStandardTestSuite) SetupTest() {
suite.state.Caches.Init()
testrig.InitTestConfig() testrig.InitTestConfig()
testrig.InitTestLog() testrig.InitTestLog()
suite.db = testrig.NewTestDB() suite.db = testrig.NewTestDB(&suite.state)
suite.state.DB = suite.db
suite.tc = testrig.NewTestTypeConverter(suite.db) suite.tc = testrig.NewTestTypeConverter(suite.db)
suite.storage = testrig.NewInMemoryStorage() suite.storage = testrig.NewInMemoryStorage()
suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage) suite.state.Storage = suite.storage
suite.transportController = testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../testrig/media"), suite.db, concurrency.NewWorkerPool[messages.FromFederator](-1, -1)) suite.mediaManager = testrig.NewTestMediaManager(&suite.state)
suite.mediaProcessor = mediaprocessing.New(suite.db, suite.tc, suite.mediaManager, suite.transportController, suite.storage) suite.transportController = testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../testrig/media"))
suite.mediaProcessor = mediaprocessing.New(&suite.state, suite.tc, suite.mediaManager, suite.transportController)
testrig.StandardDBSetup(suite.db, nil) testrig.StandardDBSetup(suite.db, nil)
testrig.StandardStorageSetup(suite.storage, "../../../testrig/media") testrig.StandardStorageSetup(suite.storage, "../../../testrig/media")
} }

View file

@ -33,7 +33,7 @@ import (
// Unattach unattaches the media attachment with the given ID from any statuses it was attached to, making it available // Unattach unattaches the media attachment with the given ID from any statuses it was attached to, making it available
// for reattachment again. // for reattachment again.
func (p *Processor) Unattach(ctx context.Context, account *gtsmodel.Account, mediaAttachmentID string) (*apimodel.Attachment, gtserror.WithCode) { func (p *Processor) Unattach(ctx context.Context, account *gtsmodel.Account, mediaAttachmentID string) (*apimodel.Attachment, gtserror.WithCode) {
attachment, err := p.db.GetAttachmentByID(ctx, mediaAttachmentID) attachment, err := p.state.DB.GetAttachmentByID(ctx, mediaAttachmentID)
if err != nil { if err != nil {
if err == db.ErrNoEntries { if err == db.ErrNoEntries {
return nil, gtserror.NewErrorNotFound(errors.New("attachment doesn't exist in the db")) return nil, gtserror.NewErrorNotFound(errors.New("attachment doesn't exist in the db"))
@ -49,7 +49,7 @@ func (p *Processor) Unattach(ctx context.Context, account *gtsmodel.Account, med
attachment.UpdatedAt = time.Now() attachment.UpdatedAt = time.Now()
attachment.StatusID = "" attachment.StatusID = ""
if err := p.db.UpdateByID(ctx, attachment, attachment.ID, updatingColumns...); err != nil { if err := p.state.DB.UpdateByID(ctx, attachment, attachment.ID, updatingColumns...); err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("db error updating attachment: %s", err)) return nil, gtserror.NewErrorNotFound(fmt.Errorf("db error updating attachment: %s", err))
} }

View file

@ -32,7 +32,7 @@ import (
// Update updates a media attachment with the given id, using the provided form parameters. // Update updates a media attachment with the given id, using the provided form parameters.
func (p *Processor) Update(ctx context.Context, account *gtsmodel.Account, mediaAttachmentID string, form *apimodel.AttachmentUpdateRequest) (*apimodel.Attachment, gtserror.WithCode) { func (p *Processor) Update(ctx context.Context, account *gtsmodel.Account, mediaAttachmentID string, form *apimodel.AttachmentUpdateRequest) (*apimodel.Attachment, gtserror.WithCode) {
attachment, err := p.db.GetAttachmentByID(ctx, mediaAttachmentID) attachment, err := p.state.DB.GetAttachmentByID(ctx, mediaAttachmentID)
if err != nil { if err != nil {
if err == db.ErrNoEntries { if err == db.ErrNoEntries {
// attachment doesn't exist // attachment doesn't exist
@ -62,7 +62,7 @@ func (p *Processor) Update(ctx context.Context, account *gtsmodel.Account, media
updatingColumns = append(updatingColumns, "focus_x", "focus_y") updatingColumns = append(updatingColumns, "focus_x", "focus_y")
} }
if err := p.db.UpdateByID(ctx, attachment, attachment.ID, updatingColumns...); err != nil { if err := p.state.DB.UpdateByID(ctx, attachment, attachment.ID, updatingColumns...); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("database error updating media: %s", err)) return nil, gtserror.NewErrorInternalError(fmt.Errorf("database error updating media: %s", err))
} }

View file

@ -29,7 +29,7 @@ import (
) )
func (p *Processor) NotificationsGet(ctx context.Context, authed *oauth.Auth, excludeTypes []string, limit int, maxID string, sinceID string) (*apimodel.PageableResponse, gtserror.WithCode) { func (p *Processor) NotificationsGet(ctx context.Context, authed *oauth.Auth, excludeTypes []string, limit int, maxID string, sinceID string) (*apimodel.PageableResponse, gtserror.WithCode) {
notifs, err := p.db.GetNotifications(ctx, authed.Account.ID, excludeTypes, limit, maxID, sinceID) notifs, err := p.state.DB.GetNotifications(ctx, authed.Account.ID, excludeTypes, limit, maxID, sinceID)
if err != nil { if err != nil {
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
} }
@ -72,7 +72,7 @@ func (p *Processor) NotificationsGet(ctx context.Context, authed *oauth.Auth, ex
} }
func (p *Processor) NotificationsClear(ctx context.Context, authed *oauth.Auth) gtserror.WithCode { func (p *Processor) NotificationsClear(ctx context.Context, authed *oauth.Auth) gtserror.WithCode {
err := p.db.ClearNotifications(ctx, authed.Account.ID) err := p.state.DB.ClearNotifications(ctx, authed.Account.ID)
if err != nil { if err != nil {
return gtserror.NewErrorInternalError(err) return gtserror.NewErrorInternalError(err)
} }

View file

@ -19,10 +19,11 @@
package processing package processing
import ( import (
"github.com/superseriousbusiness/gotosocial/internal/concurrency" "context"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email" "github.com/superseriousbusiness/gotosocial/internal/email"
"github.com/superseriousbusiness/gotosocial/internal/federation" "github.com/superseriousbusiness/gotosocial/internal/federation"
"github.com/superseriousbusiness/gotosocial/internal/log"
mm "github.com/superseriousbusiness/gotosocial/internal/media" mm "github.com/superseriousbusiness/gotosocial/internal/media"
"github.com/superseriousbusiness/gotosocial/internal/messages" "github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/oauth"
@ -34,23 +35,19 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/processing/status" "github.com/superseriousbusiness/gotosocial/internal/processing/status"
"github.com/superseriousbusiness/gotosocial/internal/processing/stream" "github.com/superseriousbusiness/gotosocial/internal/processing/stream"
"github.com/superseriousbusiness/gotosocial/internal/processing/user" "github.com/superseriousbusiness/gotosocial/internal/processing/user"
"github.com/superseriousbusiness/gotosocial/internal/storage" "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/timeline" "github.com/superseriousbusiness/gotosocial/internal/timeline"
"github.com/superseriousbusiness/gotosocial/internal/typeutils" "github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/visibility" "github.com/superseriousbusiness/gotosocial/internal/visibility"
) )
type Processor struct { type Processor struct {
clientWorker *concurrency.WorkerPool[messages.FromClientAPI]
fedWorker *concurrency.WorkerPool[messages.FromFederator]
federator federation.Federator federator federation.Federator
tc typeutils.TypeConverter tc typeutils.TypeConverter
oauthServer oauth.Server oauthServer oauth.Server
mediaManager mm.Manager mediaManager mm.Manager
storage *storage.Driver
statusTimelines timeline.Manager statusTimelines timeline.Manager
db db.DB state *state.State
filter visibility.Filter filter visibility.Filter
/* /*
@ -105,76 +102,65 @@ func NewProcessor(
federator federation.Federator, federator federation.Federator,
oauthServer oauth.Server, oauthServer oauth.Server,
mediaManager mm.Manager, mediaManager mm.Manager,
storage *storage.Driver, state *state.State,
db db.DB,
emailSender email.Sender, emailSender email.Sender,
clientWorker *concurrency.WorkerPool[messages.FromClientAPI],
fedWorker *concurrency.WorkerPool[messages.FromFederator],
) *Processor { ) *Processor {
parseMentionFunc := GetParseMentionFunc(db, federator) parseMentionFunc := GetParseMentionFunc(state.DB, federator)
filter := visibility.NewFilter(db) filter := visibility.NewFilter(state.DB)
return &Processor{ processor := &Processor{
clientWorker: clientWorker, federator: federator,
fedWorker: fedWorker, tc: tc,
oauthServer: oauthServer,
federator: federator, mediaManager: mediaManager,
tc: tc, statusTimelines: timeline.NewManager(
oauthServer: oauthServer, StatusGrabFunction(state.DB),
mediaManager: mediaManager, StatusFilterFunction(state.DB, filter),
storage: storage, StatusPrepareFunction(state.DB, tc),
statusTimelines: timeline.NewManager(StatusGrabFunction(db), StatusFilterFunction(db, filter), StatusPrepareFunction(db, tc), StatusSkipInsertFunction()), StatusSkipInsertFunction(),
db: db, ),
filter: filter, state: state,
filter: filter,
// sub processors
account: account.New(db, tc, mediaManager, oauthServer, clientWorker, federator, parseMentionFunc),
admin: admin.New(db, tc, mediaManager, federator.TransportController(), storage, clientWorker),
fedi: fedi.New(db, tc, federator),
media: media.New(db, tc, mediaManager, federator.TransportController(), storage),
report: report.New(db, tc, clientWorker),
status: status.New(db, tc, clientWorker, parseMentionFunc),
stream: stream.New(db, oauthServer),
user: user.New(db, emailSender),
} }
// sub processors
processor.account = account.New(state, tc, mediaManager, oauthServer, federator, parseMentionFunc)
processor.admin = admin.New(state, tc, mediaManager, federator.TransportController())
processor.fedi = fedi.New(state, tc, federator)
processor.media = media.New(state, tc, mediaManager, federator.TransportController())
processor.report = report.New(state, tc)
processor.status = status.New(state, tc, parseMentionFunc)
processor.stream = stream.New(state, oauthServer)
processor.user = user.New(state, emailSender)
return processor
} }
// Start starts the Processor, reading from its channels and passing messages back and forth. func (p *Processor) EnqueueClientAPI(ctx context.Context, msg messages.FromClientAPI) {
log.WithContext(ctx).WithField("msg", msg).Trace("enqueuing client API")
_ = p.state.Workers.ClientAPI.MustEnqueueCtx(ctx, func(ctx context.Context) {
if err := p.ProcessFromClientAPI(ctx, msg); err != nil {
log.Errorf(ctx, "error processing client API message: %v", err)
}
})
}
func (p *Processor) EnqueueFederator(ctx context.Context, msg messages.FromFederator) {
log.WithContext(ctx).WithField("msg", msg).Trace("enqueuing federator")
_ = p.state.Workers.Federator.MustEnqueueCtx(ctx, func(ctx context.Context) {
if err := p.ProcessFromFederator(ctx, msg); err != nil {
log.Errorf(ctx, "error processing federator message: %v", err)
}
})
}
// Start starts the Processor.
func (p *Processor) Start() error { func (p *Processor) Start() error {
// Setup and start the client API worker pool return p.statusTimelines.Start()
p.clientWorker.SetProcessor(p.ProcessFromClientAPI)
if err := p.clientWorker.Start(); err != nil {
return err
}
// Setup and start the federator worker pool
p.fedWorker.SetProcessor(p.ProcessFromFederator)
if err := p.fedWorker.Start(); err != nil {
return err
}
// Start status timelines
if err := p.statusTimelines.Start(); err != nil {
return err
}
return nil
} }
// Stop stops the processor cleanly, finishing handling any remaining messages before closing down. // Stop stops the processor cleanly.
func (p *Processor) Stop() error { func (p *Processor) Stop() error {
if err := p.clientWorker.Stop(); err != nil { return p.statusTimelines.Stop()
return err
}
if err := p.fedWorker.Stop(); err != nil {
return err
}
if err := p.statusTimelines.Stop(); err != nil {
return err
}
return nil
} }

View file

@ -20,15 +20,14 @@ package processing_test
import ( import (
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email" "github.com/superseriousbusiness/gotosocial/internal/email"
"github.com/superseriousbusiness/gotosocial/internal/federation" "github.com/superseriousbusiness/gotosocial/internal/federation"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/media" "github.com/superseriousbusiness/gotosocial/internal/media"
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/processing" "github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage" "github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/transport" "github.com/superseriousbusiness/gotosocial/internal/transport"
"github.com/superseriousbusiness/gotosocial/internal/typeutils" "github.com/superseriousbusiness/gotosocial/internal/typeutils"
@ -40,6 +39,7 @@ type ProcessingStandardTestSuite struct {
suite.Suite suite.Suite
db db.DB db db.DB
storage *storage.Driver storage *storage.Driver
state state.State
mediaManager media.Manager mediaManager media.Manager
typeconverter typeutils.TypeConverter typeconverter typeutils.TypeConverter
httpClient *testrig.MockHTTPClient httpClient *testrig.MockHTTPClient
@ -86,25 +86,29 @@ func (suite *ProcessingStandardTestSuite) SetupSuite() {
} }
func (suite *ProcessingStandardTestSuite) SetupTest() { func (suite *ProcessingStandardTestSuite) SetupTest() {
suite.state.Caches.Init()
testrig.StartWorkers(&suite.state)
testrig.InitTestConfig() testrig.InitTestConfig()
testrig.InitTestLog() testrig.InitTestLog()
suite.db = testrig.NewTestDB() suite.db = testrig.NewTestDB(&suite.state)
suite.state.DB = suite.db
suite.testActivities = testrig.NewTestActivities(suite.testAccounts) suite.testActivities = testrig.NewTestActivities(suite.testAccounts)
suite.storage = testrig.NewInMemoryStorage() suite.storage = testrig.NewInMemoryStorage()
suite.state.Storage = suite.storage
suite.typeconverter = testrig.NewTestTypeConverter(suite.db) suite.typeconverter = testrig.NewTestTypeConverter(suite.db)
suite.httpClient = testrig.NewMockHTTPClient(nil, "../../testrig/media") suite.httpClient = testrig.NewMockHTTPClient(nil, "../../testrig/media")
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) suite.transportController = testrig.NewTestTransportController(&suite.state, suite.httpClient)
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) suite.mediaManager = testrig.NewTestMediaManager(&suite.state)
suite.federator = testrig.NewTestFederator(&suite.state, suite.transportController, suite.mediaManager)
suite.transportController = testrig.NewTestTransportController(suite.httpClient, suite.db, fedWorker)
suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage)
suite.federator = testrig.NewTestFederator(suite.db, suite.transportController, suite.storage, suite.mediaManager, fedWorker)
suite.oauthServer = testrig.NewTestOauthServer(suite.db) suite.oauthServer = testrig.NewTestOauthServer(suite.db)
suite.emailSender = testrig.NewEmailSender("../../web/template/", nil) suite.emailSender = testrig.NewEmailSender("../../web/template/", nil)
suite.processor = processing.NewProcessor(suite.typeconverter, suite.federator, suite.oauthServer, suite.mediaManager, suite.storage, suite.db, suite.emailSender, clientWorker, fedWorker) suite.processor = processing.NewProcessor(suite.typeconverter, suite.federator, suite.oauthServer, suite.mediaManager, &suite.state, suite.emailSender)
suite.state.Workers.EnqueueClientAPI = suite.processor.EnqueueClientAPI
suite.state.Workers.EnqueueFederator = suite.processor.EnqueueFederator
testrig.StandardDBSetup(suite.db, suite.testAccounts) testrig.StandardDBSetup(suite.db, suite.testAccounts)
testrig.StandardStorageSetup(suite.storage, "../../testrig/media") testrig.StandardStorageSetup(suite.storage, "../../testrig/media")
@ -119,4 +123,5 @@ func (suite *ProcessingStandardTestSuite) TearDownTest() {
if err := suite.processor.Stop(); err != nil { if err := suite.processor.Stop(); err != nil {
panic(err) panic(err)
} }
testrig.StopWorkers(&suite.state)
} }

View file

@ -41,7 +41,7 @@ func (p *Processor) Create(ctx context.Context, account *gtsmodel.Account, form
} }
// validate + fetch target account // validate + fetch target account
targetAccount, err := p.db.GetAccountByID(ctx, form.AccountID) targetAccount, err := p.state.DB.GetAccountByID(ctx, form.AccountID)
if err != nil { if err != nil {
if errors.Is(err, db.ErrNoEntries) { if errors.Is(err, db.ErrNoEntries) {
err = fmt.Errorf("account with ID %s does not exist", form.AccountID) err = fmt.Errorf("account with ID %s does not exist", form.AccountID)
@ -52,7 +52,7 @@ func (p *Processor) Create(ctx context.Context, account *gtsmodel.Account, form
} }
// fetch statuses by IDs given in the report form (noop if no statuses given) // fetch statuses by IDs given in the report form (noop if no statuses given)
statuses, err := p.db.GetStatuses(ctx, form.StatusIDs) statuses, err := p.state.DB.GetStatuses(ctx, form.StatusIDs)
if err != nil { if err != nil {
err = fmt.Errorf("db error fetching report target statuses: %w", err) err = fmt.Errorf("db error fetching report target statuses: %w", err)
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
@ -79,11 +79,11 @@ func (p *Processor) Create(ctx context.Context, account *gtsmodel.Account, form
Forwarded: &form.Forward, Forwarded: &form.Forward,
} }
if err := p.db.PutReport(ctx, report); err != nil { if err := p.state.DB.PutReport(ctx, report); err != nil {
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
} }
p.clientWorker.Queue(messages.FromClientAPI{ p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{
APObjectType: ap.ObjectProfile, APObjectType: ap.ObjectProfile,
APActivityType: ap.ActivityFlag, APActivityType: ap.ActivityFlag,
GTSModel: report, GTSModel: report,

View file

@ -32,7 +32,7 @@ import (
// Get returns the user view of a moderation report, with the given id. // Get returns the user view of a moderation report, with the given id.
func (p *Processor) Get(ctx context.Context, account *gtsmodel.Account, id string) (*apimodel.Report, gtserror.WithCode) { func (p *Processor) Get(ctx context.Context, account *gtsmodel.Account, id string) (*apimodel.Report, gtserror.WithCode) {
report, err := p.db.GetReportByID(ctx, id) report, err := p.state.DB.GetReportByID(ctx, id)
if err != nil { if err != nil {
if err == db.ErrNoEntries { if err == db.ErrNoEntries {
return nil, gtserror.NewErrorNotFound(err) return nil, gtserror.NewErrorNotFound(err)
@ -64,7 +64,7 @@ func (p *Processor) GetMultiple(
minID string, minID string,
limit int, limit int,
) (*apimodel.PageableResponse, gtserror.WithCode) { ) (*apimodel.PageableResponse, gtserror.WithCode) {
reports, err := p.db.GetReports(ctx, resolved, account.ID, targetAccountID, maxID, sinceID, minID, limit) reports, err := p.state.DB.GetReports(ctx, resolved, account.ID, targetAccountID, maxID, sinceID, minID, limit)
if err != nil { if err != nil {
if err == db.ErrNoEntries { if err == db.ErrNoEntries {
return util.EmptyPageableResponse(), nil return util.EmptyPageableResponse(), nil

View file

@ -19,22 +19,18 @@
package report package report
import ( import (
"github.com/superseriousbusiness/gotosocial/internal/concurrency" "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/typeutils" "github.com/superseriousbusiness/gotosocial/internal/typeutils"
) )
type Processor struct { type Processor struct {
db db.DB state *state.State
tc typeutils.TypeConverter tc typeutils.TypeConverter
clientWorker *concurrency.WorkerPool[messages.FromClientAPI]
} }
func New(db db.DB, tc typeutils.TypeConverter, clientWorker *concurrency.WorkerPool[messages.FromClientAPI]) Processor { func New(state *state.State, tc typeutils.TypeConverter) Processor {
return Processor{ return Processor{
tc: tc, state: state,
db: db, tc: tc,
clientWorker: clientWorker,
} }
} }

View file

@ -88,7 +88,7 @@ func (p *Processor) SearchGet(ctx context.Context, authed *oauth.Auth, search *a
if username, domain, err := util.ExtractNamestringParts(maybeNamestring); err == nil { if username, domain, err := util.ExtractNamestringParts(maybeNamestring); err == nil {
l.Trace("search term is a mention, looking it up...") l.Trace("search term is a mention, looking it up...")
blocked, err := p.db.IsDomainBlocked(ctx, domain) blocked, err := p.state.DB.IsDomainBlocked(ctx, domain)
if err != nil { if err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("error checking domain block: %w", err)) return nil, gtserror.NewErrorInternalError(fmt.Errorf("error checking domain block: %w", err))
} }
@ -120,7 +120,7 @@ func (p *Processor) SearchGet(ctx context.Context, authed *oauth.Auth, search *a
if uri, err := url.Parse(query); err == nil { if uri, err := url.Parse(query); err == nil {
if uri.Scheme == "https" || uri.Scheme == "http" { if uri.Scheme == "https" || uri.Scheme == "http" {
l.Trace("search term is a uri, looking it up...") l.Trace("search term is a uri, looking it up...")
blocked, err := p.db.IsURIBlocked(ctx, uri) blocked, err := p.state.DB.IsURIBlocked(ctx, uri)
if err != nil { if err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("error checking domain block: %w", err)) return nil, gtserror.NewErrorInternalError(fmt.Errorf("error checking domain block: %w", err))
} }
@ -178,7 +178,7 @@ func (p *Processor) SearchGet(ctx context.Context, authed *oauth.Auth, search *a
*/ */
for _, foundAccount := range foundAccounts { for _, foundAccount := range foundAccounts {
// make sure there's no block in either direction between the account and the requester // make sure there's no block in either direction between the account and the requester
blocked, err := p.db.IsBlocked(ctx, authed.Account.ID, foundAccount.ID, true) blocked, err := p.state.DB.IsBlocked(ctx, authed.Account.ID, foundAccount.ID, true)
if err != nil { if err != nil {
err = fmt.Errorf("SearchGet: error checking block between %s and %s: %s", authed.Account.ID, foundAccount.ID, err) err = fmt.Errorf("SearchGet: error checking block between %s and %s: %s", authed.Account.ID, foundAccount.ID, err)
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
@ -246,14 +246,14 @@ func (p *Processor) searchAccountByURI(ctx context.Context, authed *oauth.Auth,
) )
// Search the database for existing account with ID URI. // Search the database for existing account with ID URI.
account, err = p.db.GetAccountByURI(ctx, uriStr) account, err = p.state.DB.GetAccountByURI(ctx, uriStr)
if err != nil && !errors.Is(err, db.ErrNoEntries) { if err != nil && !errors.Is(err, db.ErrNoEntries) {
return nil, fmt.Errorf("searchAccountByURI: error checking database for account %s: %w", uriStr, err) return nil, fmt.Errorf("searchAccountByURI: error checking database for account %s: %w", uriStr, err)
} }
if account == nil { if account == nil {
// Else, search the database for existing by ID URL. // Else, search the database for existing by ID URL.
account, err = p.db.GetAccountByURL(ctx, uriStr) account, err = p.state.DB.GetAccountByURL(ctx, uriStr)
if err != nil { if err != nil {
if !errors.Is(err, db.ErrNoEntries) { if !errors.Is(err, db.ErrNoEntries) {
return nil, fmt.Errorf("searchAccountByURI: error checking database for account %s: %w", uriStr, err) return nil, fmt.Errorf("searchAccountByURI: error checking database for account %s: %w", uriStr, err)
@ -281,7 +281,7 @@ func (p *Processor) searchAccountByUsernameDomain(ctx context.Context, authed *o
} }
// Search the database for existing account with USERNAME@DOMAIN // Search the database for existing account with USERNAME@DOMAIN
account, err := p.db.GetAccountByUsernameDomain(ctx, username, domain) account, err := p.state.DB.GetAccountByUsernameDomain(ctx, username, domain)
if err != nil { if err != nil {
if !errors.Is(err, db.ErrNoEntries) { if !errors.Is(err, db.ErrNoEntries) {
return nil, fmt.Errorf("searchAccountByUsernameDomain: error checking database for account %s@%s: %w", username, domain, err) return nil, fmt.Errorf("searchAccountByUsernameDomain: error checking database for account %s@%s: %w", username, domain, err)

View file

@ -32,7 +32,7 @@ import (
// BookmarkCreate adds a bookmark for the requestingAccount, targeting the given status (no-op if bookmark already exists). // BookmarkCreate adds a bookmark for the requestingAccount, targeting the given status (no-op if bookmark already exists).
func (p *Processor) BookmarkCreate(ctx context.Context, requestingAccount *gtsmodel.Account, targetStatusID string) (*apimodel.Status, gtserror.WithCode) { func (p *Processor) BookmarkCreate(ctx context.Context, requestingAccount *gtsmodel.Account, targetStatusID string) (*apimodel.Status, gtserror.WithCode) {
targetStatus, err := p.db.GetStatusByID(ctx, targetStatusID) targetStatus, err := p.state.DB.GetStatusByID(ctx, targetStatusID)
if err != nil { if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("error fetching status %s: %s", targetStatusID, err)) return nil, gtserror.NewErrorNotFound(fmt.Errorf("error fetching status %s: %s", targetStatusID, err))
} }
@ -50,7 +50,7 @@ func (p *Processor) BookmarkCreate(ctx context.Context, requestingAccount *gtsmo
// first check if the status is already bookmarked, if so we don't need to do anything // first check if the status is already bookmarked, if so we don't need to do anything
newBookmark := true newBookmark := true
gtsBookmark := &gtsmodel.StatusBookmark{} gtsBookmark := &gtsmodel.StatusBookmark{}
if err := p.db.GetWhere(ctx, []db.Where{{Key: "status_id", Value: targetStatus.ID}, {Key: "account_id", Value: requestingAccount.ID}}, gtsBookmark); err == nil { if err := p.state.DB.GetWhere(ctx, []db.Where{{Key: "status_id", Value: targetStatus.ID}, {Key: "account_id", Value: requestingAccount.ID}}, gtsBookmark); err == nil {
// we already have a bookmark for this status // we already have a bookmark for this status
newBookmark = false newBookmark = false
} }
@ -67,7 +67,7 @@ func (p *Processor) BookmarkCreate(ctx context.Context, requestingAccount *gtsmo
Status: targetStatus, Status: targetStatus,
} }
if err := p.db.Put(ctx, gtsBookmark); err != nil { if err := p.state.DB.Put(ctx, gtsBookmark); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("error putting bookmark in database: %s", err)) return nil, gtserror.NewErrorInternalError(fmt.Errorf("error putting bookmark in database: %s", err))
} }
} }
@ -83,7 +83,7 @@ func (p *Processor) BookmarkCreate(ctx context.Context, requestingAccount *gtsmo
// BookmarkRemove removes a bookmark for the requesting account, targeting the given status (no-op if bookmark doesn't exist). // BookmarkRemove removes a bookmark for the requesting account, targeting the given status (no-op if bookmark doesn't exist).
func (p *Processor) BookmarkRemove(ctx context.Context, requestingAccount *gtsmodel.Account, targetStatusID string) (*apimodel.Status, gtserror.WithCode) { func (p *Processor) BookmarkRemove(ctx context.Context, requestingAccount *gtsmodel.Account, targetStatusID string) (*apimodel.Status, gtserror.WithCode) {
targetStatus, err := p.db.GetStatusByID(ctx, targetStatusID) targetStatus, err := p.state.DB.GetStatusByID(ctx, targetStatusID)
if err != nil { if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("error fetching status %s: %s", targetStatusID, err)) return nil, gtserror.NewErrorNotFound(fmt.Errorf("error fetching status %s: %s", targetStatusID, err))
} }
@ -101,13 +101,13 @@ func (p *Processor) BookmarkRemove(ctx context.Context, requestingAccount *gtsmo
// first check if the status is actually bookmarked // first check if the status is actually bookmarked
toUnbookmark := false toUnbookmark := false
gtsBookmark := &gtsmodel.StatusBookmark{} gtsBookmark := &gtsmodel.StatusBookmark{}
if err := p.db.GetWhere(ctx, []db.Where{{Key: "status_id", Value: targetStatus.ID}, {Key: "account_id", Value: requestingAccount.ID}}, gtsBookmark); err == nil { if err := p.state.DB.GetWhere(ctx, []db.Where{{Key: "status_id", Value: targetStatus.ID}, {Key: "account_id", Value: requestingAccount.ID}}, gtsBookmark); err == nil {
// we have a bookmark for this status // we have a bookmark for this status
toUnbookmark = true toUnbookmark = true
} }
if toUnbookmark { if toUnbookmark {
if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "status_id", Value: targetStatus.ID}, {Key: "account_id", Value: requestingAccount.ID}}, gtsBookmark); err != nil { if err := p.state.DB.DeleteWhere(ctx, []db.Where{{Key: "status_id", Value: targetStatus.ID}, {Key: "account_id", Value: requestingAccount.ID}}, gtsBookmark); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("error unfaveing status: %s", err)) return nil, gtserror.NewErrorInternalError(fmt.Errorf("error unfaveing status: %s", err))
} }
} }

View file

@ -33,7 +33,7 @@ import (
// BoostCreate processes the boost/reblog of a given status, returning the newly-created boost if all is well. // BoostCreate processes the boost/reblog of a given status, returning the newly-created boost if all is well.
func (p *Processor) BoostCreate(ctx context.Context, requestingAccount *gtsmodel.Account, application *gtsmodel.Application, targetStatusID string) (*apimodel.Status, gtserror.WithCode) { func (p *Processor) BoostCreate(ctx context.Context, requestingAccount *gtsmodel.Account, application *gtsmodel.Application, targetStatusID string) (*apimodel.Status, gtserror.WithCode) {
targetStatus, err := p.db.GetStatusByID(ctx, targetStatusID) targetStatus, err := p.state.DB.GetStatusByID(ctx, targetStatusID)
if err != nil { if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("error fetching status %s: %s", targetStatusID, err)) return nil, gtserror.NewErrorNotFound(fmt.Errorf("error fetching status %s: %s", targetStatusID, err))
} }
@ -47,7 +47,7 @@ func (p *Processor) BoostCreate(ctx context.Context, requestingAccount *gtsmodel
// boost boosts, and it looks absolutely bizarre in the UI // boost boosts, and it looks absolutely bizarre in the UI
if targetStatus.BoostOfID != "" { if targetStatus.BoostOfID != "" {
if targetStatus.BoostOf == nil { if targetStatus.BoostOf == nil {
b, err := p.db.GetStatusByID(ctx, targetStatus.BoostOfID) b, err := p.state.DB.GetStatusByID(ctx, targetStatus.BoostOfID)
if err != nil { if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("couldn't fetch boosted status %s", targetStatus.BoostOfID)) return nil, gtserror.NewErrorNotFound(fmt.Errorf("couldn't fetch boosted status %s", targetStatus.BoostOfID))
} }
@ -74,12 +74,12 @@ func (p *Processor) BoostCreate(ctx context.Context, requestingAccount *gtsmodel
boostWrapperStatus.BoostOfAccount = targetStatus.Account boostWrapperStatus.BoostOfAccount = targetStatus.Account
// put the boost in the database // put the boost in the database
if err := p.db.PutStatus(ctx, boostWrapperStatus); err != nil { if err := p.state.DB.PutStatus(ctx, boostWrapperStatus); err != nil {
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
} }
// send it back to the processor for async processing // send it back to the processor for async processing
p.clientWorker.Queue(messages.FromClientAPI{ p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{
APObjectType: ap.ActivityAnnounce, APObjectType: ap.ActivityAnnounce,
APActivityType: ap.ActivityCreate, APActivityType: ap.ActivityCreate,
GTSModel: boostWrapperStatus, GTSModel: boostWrapperStatus,
@ -98,7 +98,7 @@ func (p *Processor) BoostCreate(ctx context.Context, requestingAccount *gtsmodel
// BoostRemove processes the unboost/unreblog of a given status, returning the status if all is well. // BoostRemove processes the unboost/unreblog of a given status, returning the status if all is well.
func (p *Processor) BoostRemove(ctx context.Context, requestingAccount *gtsmodel.Account, application *gtsmodel.Application, targetStatusID string) (*apimodel.Status, gtserror.WithCode) { func (p *Processor) BoostRemove(ctx context.Context, requestingAccount *gtsmodel.Account, application *gtsmodel.Application, targetStatusID string) (*apimodel.Status, gtserror.WithCode) {
targetStatus, err := p.db.GetStatusByID(ctx, targetStatusID) targetStatus, err := p.state.DB.GetStatusByID(ctx, targetStatusID)
if err != nil { if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("error fetching status %s: %s", targetStatusID, err)) return nil, gtserror.NewErrorNotFound(fmt.Errorf("error fetching status %s: %s", targetStatusID, err))
} }
@ -128,7 +128,7 @@ func (p *Processor) BoostRemove(ctx context.Context, requestingAccount *gtsmodel
Value: requestingAccount.ID, Value: requestingAccount.ID,
}, },
} }
err = p.db.GetWhere(ctx, where, gtsBoost) err = p.state.DB.GetWhere(ctx, where, gtsBoost)
if err == nil { if err == nil {
// we have a boost // we have a boost
toUnboost = true toUnboost = true
@ -151,7 +151,7 @@ func (p *Processor) BoostRemove(ctx context.Context, requestingAccount *gtsmodel
gtsBoost.BoostOf.Account = targetStatus.Account gtsBoost.BoostOf.Account = targetStatus.Account
// send it back to the processor for async processing // send it back to the processor for async processing
p.clientWorker.Queue(messages.FromClientAPI{ p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{
APObjectType: ap.ActivityAnnounce, APObjectType: ap.ActivityAnnounce,
APActivityType: ap.ActivityUndo, APActivityType: ap.ActivityUndo,
GTSModel: gtsBoost, GTSModel: gtsBoost,
@ -170,7 +170,7 @@ func (p *Processor) BoostRemove(ctx context.Context, requestingAccount *gtsmodel
// StatusBoostedBy returns a slice of accounts that have boosted the given status, filtered according to privacy settings. // StatusBoostedBy returns a slice of accounts that have boosted the given status, filtered according to privacy settings.
func (p *Processor) StatusBoostedBy(ctx context.Context, requestingAccount *gtsmodel.Account, targetStatusID string) ([]*apimodel.Account, gtserror.WithCode) { func (p *Processor) StatusBoostedBy(ctx context.Context, requestingAccount *gtsmodel.Account, targetStatusID string) ([]*apimodel.Account, gtserror.WithCode) {
targetStatus, err := p.db.GetStatusByID(ctx, targetStatusID) targetStatus, err := p.state.DB.GetStatusByID(ctx, targetStatusID)
if err != nil { if err != nil {
wrapped := fmt.Errorf("BoostedBy: error fetching status %s: %s", targetStatusID, err) wrapped := fmt.Errorf("BoostedBy: error fetching status %s: %s", targetStatusID, err)
if !errors.Is(err, db.ErrNoEntries) { if !errors.Is(err, db.ErrNoEntries) {
@ -181,7 +181,7 @@ func (p *Processor) StatusBoostedBy(ctx context.Context, requestingAccount *gtsm
if boostOfID := targetStatus.BoostOfID; boostOfID != "" { if boostOfID := targetStatus.BoostOfID; boostOfID != "" {
// the target status is a boost wrapper, redirect this request to the status it boosts // the target status is a boost wrapper, redirect this request to the status it boosts
boostedStatus, err := p.db.GetStatusByID(ctx, boostOfID) boostedStatus, err := p.state.DB.GetStatusByID(ctx, boostOfID)
if err != nil { if err != nil {
wrapped := fmt.Errorf("BoostedBy: error fetching status %s: %s", boostOfID, err) wrapped := fmt.Errorf("BoostedBy: error fetching status %s: %s", boostOfID, err)
if !errors.Is(err, db.ErrNoEntries) { if !errors.Is(err, db.ErrNoEntries) {
@ -202,7 +202,7 @@ func (p *Processor) StatusBoostedBy(ctx context.Context, requestingAccount *gtsm
return nil, gtserror.NewErrorNotFound(err) return nil, gtserror.NewErrorNotFound(err)
} }
statusReblogs, err := p.db.GetStatusReblogs(ctx, targetStatus) statusReblogs, err := p.state.DB.GetStatusReblogs(ctx, targetStatus)
if err != nil { if err != nil {
err = fmt.Errorf("BoostedBy: error seeing who boosted status: %s", err) err = fmt.Errorf("BoostedBy: error seeing who boosted status: %s", err)
return nil, gtserror.NewErrorNotFound(err) return nil, gtserror.NewErrorNotFound(err)
@ -211,7 +211,7 @@ func (p *Processor) StatusBoostedBy(ctx context.Context, requestingAccount *gtsm
// filter account IDs so the user doesn't see accounts they blocked or which blocked them // filter account IDs so the user doesn't see accounts they blocked or which blocked them
accountIDs := make([]string, 0, len(statusReblogs)) accountIDs := make([]string, 0, len(statusReblogs))
for _, s := range statusReblogs { for _, s := range statusReblogs {
blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, s.AccountID, true) blocked, err := p.state.DB.IsBlocked(ctx, requestingAccount.ID, s.AccountID, true)
if err != nil { if err != nil {
err = fmt.Errorf("BoostedBy: error checking blocks: %s", err) err = fmt.Errorf("BoostedBy: error checking blocks: %s", err)
return nil, gtserror.NewErrorNotFound(err) return nil, gtserror.NewErrorNotFound(err)
@ -226,7 +226,7 @@ func (p *Processor) StatusBoostedBy(ctx context.Context, requestingAccount *gtsm
// fetch accounts + create their API representations // fetch accounts + create their API representations
apiAccounts := make([]*apimodel.Account, 0, len(accountIDs)) apiAccounts := make([]*apimodel.Account, 0, len(accountIDs))
for _, accountID := range accountIDs { for _, accountID := range accountIDs {
account, err := p.db.GetAccountByID(ctx, accountID) account, err := p.state.DB.GetAccountByID(ctx, accountID)
if err != nil { if err != nil {
wrapped := fmt.Errorf("BoostedBy: error fetching account %s: %s", accountID, err) wrapped := fmt.Errorf("BoostedBy: error fetching account %s: %s", accountID, err)
if !errors.Is(err, db.ErrNoEntries) { if !errors.Is(err, db.ErrNoEntries) {

View file

@ -61,11 +61,11 @@ func (p *Processor) Create(ctx context.Context, account *gtsmodel.Account, appli
Text: form.Status, Text: form.Status,
} }
if errWithCode := processReplyToID(ctx, p.db, form, account.ID, newStatus); errWithCode != nil { if errWithCode := processReplyToID(ctx, p.state.DB, form, account.ID, newStatus); errWithCode != nil {
return nil, errWithCode return nil, errWithCode
} }
if errWithCode := processMediaIDs(ctx, p.db, form, account.ID, newStatus); errWithCode != nil { if errWithCode := processMediaIDs(ctx, p.state.DB, form, account.ID, newStatus); errWithCode != nil {
return nil, errWithCode return nil, errWithCode
} }
@ -77,17 +77,17 @@ func (p *Processor) Create(ctx context.Context, account *gtsmodel.Account, appli
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
} }
if err := processContent(ctx, p.db, p.formatter, p.parseMention, form, account.ID, newStatus); err != nil { if err := processContent(ctx, p.state.DB, p.formatter, p.parseMention, form, account.ID, newStatus); err != nil {
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
} }
// put the new status in the database // put the new status in the database
if err := p.db.PutStatus(ctx, newStatus); err != nil { if err := p.state.DB.PutStatus(ctx, newStatus); err != nil {
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
} }
// send it back to the processor for async processing // send it back to the processor for async processing
p.clientWorker.Queue(messages.FromClientAPI{ p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{
APObjectType: ap.ObjectNote, APObjectType: ap.ObjectNote,
APActivityType: ap.ActivityCreate, APActivityType: ap.ActivityCreate,
GTSModel: newStatus, GTSModel: newStatus,

View file

@ -32,7 +32,7 @@ import (
// Delete processes the delete of a given status, returning the deleted status if the delete goes through. // Delete processes the delete of a given status, returning the deleted status if the delete goes through.
func (p *Processor) Delete(ctx context.Context, requestingAccount *gtsmodel.Account, targetStatusID string) (*apimodel.Status, gtserror.WithCode) { func (p *Processor) Delete(ctx context.Context, requestingAccount *gtsmodel.Account, targetStatusID string) (*apimodel.Status, gtserror.WithCode) {
targetStatus, err := p.db.GetStatusByID(ctx, targetStatusID) targetStatus, err := p.state.DB.GetStatusByID(ctx, targetStatusID)
if err != nil { if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("error fetching status %s: %s", targetStatusID, err)) return nil, gtserror.NewErrorNotFound(fmt.Errorf("error fetching status %s: %s", targetStatusID, err))
} }
@ -50,7 +50,7 @@ func (p *Processor) Delete(ctx context.Context, requestingAccount *gtsmodel.Acco
} }
// send the status back to the processor for async processing // send the status back to the processor for async processing
p.clientWorker.Queue(messages.FromClientAPI{ p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{
APObjectType: ap.ObjectNote, APObjectType: ap.ObjectNote,
APActivityType: ap.ActivityDelete, APActivityType: ap.ActivityDelete,
GTSModel: targetStatus, GTSModel: targetStatus,

Some files were not shown because too many files have changed in this diff Show more