From 7293d6029b43db693fd170c0c087394339da0677 Mon Sep 17 00:00:00 2001 From: kim <89579420+NyaaaWhatsUpDoc@users.noreply.github.com> Date: Tue, 12 Sep 2023 14:00:35 +0100 Subject: [PATCH] [feature] add paging to account follows, followers and follow requests endpoints (#2186) --- docs/api/swagger.yaml | 98 ++++- go.mod | 1 + go.sum | 2 + internal/api/client/accounts/follow_test.go | 411 ++++++++++++++++++ internal/api/client/accounts/followers.go | 62 ++- internal/api/client/accounts/following.go | 62 ++- internal/api/client/blocks/blocksget.go | 43 +- .../api/client/followrequests/authorize.go | 2 +- internal/api/client/followrequests/get.go | 59 ++- .../api/client/followrequests/get_test.go | 220 +++++++++- internal/api/client/followrequests/reject.go | 2 +- internal/db/bundb/relationship.go | 101 ++--- internal/db/bundb/relationship_test.go | 6 +- internal/db/bundb/timeline.go | 1 + internal/db/bundb/timeline_test.go | 2 +- internal/db/bundb/util.go | 25 ++ internal/db/relationship.go | 33 +- internal/federation/federatingdb/followers.go | 2 +- internal/federation/federatingdb/following.go | 2 +- .../federation/federatingdb/following_test.go | 4 +- internal/federation/federatingdb/inbox.go | 2 +- internal/paging/boundary.go | 48 +- internal/paging/page.go | 144 +++--- internal/paging/page_test.go | 12 +- internal/paging/parse.go | 57 ++- internal/paging/response.go | 8 +- internal/paging/response_test.go | 32 +- internal/paging/util.go | 6 - internal/processing/account/account.go | 6 + internal/processing/account/account_test.go | 4 +- internal/processing/account/block.go | 50 +++ internal/processing/account/delete.go | 8 +- internal/processing/account/follow.go | 63 +-- internal/processing/account/follow_request.go | 119 +++++ internal/processing/account/relationships.go | 166 ++++--- internal/processing/blocks.go | 86 ---- internal/processing/common/account.go.go | 238 ++++++++++ internal/processing/common/common.go | 50 +++ internal/processing/common/status.go | 248 +++++++++++ internal/processing/followrequest.go | 123 ------ internal/processing/followrequest_test.go | 76 ++-- internal/processing/processor.go | 4 +- internal/timeline/get_test.go | 2 + testrig/testmodels.go | 4 + .../tomnomnom/linkheader/.gitignore | 2 + .../tomnomnom/linkheader/.travis.yml | 6 + .../tomnomnom/linkheader/CONTRIBUTING.mkd | 10 + .../github.com/tomnomnom/linkheader/LICENSE | 21 + .../tomnomnom/linkheader/README.mkd | 35 ++ .../github.com/tomnomnom/linkheader/main.go | 151 +++++++ vendor/modules.txt | 3 + 51 files changed, 2281 insertions(+), 641 deletions(-) create mode 100644 internal/processing/account/follow_request.go delete mode 100644 internal/processing/blocks.go create mode 100644 internal/processing/common/account.go.go create mode 100644 internal/processing/common/common.go create mode 100644 internal/processing/common/status.go delete mode 100644 internal/processing/followrequest.go create mode 100644 vendor/github.com/tomnomnom/linkheader/.gitignore create mode 100644 vendor/github.com/tomnomnom/linkheader/.travis.yml create mode 100644 vendor/github.com/tomnomnom/linkheader/CONTRIBUTING.mkd create mode 100644 vendor/github.com/tomnomnom/linkheader/LICENSE create mode 100644 vendor/github.com/tomnomnom/linkheader/README.mkd create mode 100644 vendor/github.com/tomnomnom/linkheader/main.go diff --git a/docs/api/swagger.yaml b/docs/api/swagger.yaml index d9bf40b06..e522cdb2a 100644 --- a/docs/api/swagger.yaml +++ b/docs/api/swagger.yaml @@ -3072,6 +3072,13 @@ paths: - accounts /api/v1/accounts/{id}/followers: get: + description: |- + The next and previous queries can be parsed from the returned Link header. + Example: + + ``` + ; rel="next", ; rel="prev" + ```` operationId: accountFollowers parameters: - description: Account ID. @@ -3079,6 +3086,25 @@ paths: name: id required: true type: string + - description: 'Return only follower accounts *OLDER* than the given max ID. The follower account with the specified ID will not be included in the response. NOTE: the ID is of the internal follow, NOT any of the returned accounts.' + in: query + name: max_id + type: string + - description: 'Return only follower accounts *NEWER* than the given since ID. The follower account with the specified ID will not be included in the response. NOTE: the ID is of the internal follow, NOT any of the returned accounts.' + in: query + name: since_id + type: string + - description: 'Return only follower accounts *IMMEDIATELY NEWER* than the given min ID. The follower account with the specified ID will not be included in the response. NOTE: the ID is of the internal follow, NOT any of the returned accounts.' + in: query + name: min_id + type: string + - default: 40 + description: Number of follower accounts to return. + in: query + maximum: 80 + minimum: 1 + name: limit + type: integer produces: - application/json responses: @@ -3106,6 +3132,13 @@ paths: - accounts /api/v1/accounts/{id}/following: get: + description: |- + The next and previous queries can be parsed from the returned Link header. + Example: + + ``` + ; rel="next", ; rel="prev" + ```` operationId: accountFollowing parameters: - description: Account ID. @@ -3113,6 +3146,25 @@ paths: name: id required: true type: string + - description: 'Return only following accounts *OLDER* than the given max ID. The following account with the specified ID will not be included in the response. NOTE: the ID is of the internal follow, NOT any of the returned accounts.' + in: query + name: max_id + type: string + - description: 'Return only following accounts *NEWER* than the given since ID. The following account with the specified ID will not be included in the response. NOTE: the ID is of the internal follow, NOT any of the returned accounts.' + in: query + name: since_id + type: string + - description: 'Return only following accounts *IMMEDIATELY NEWER* than the given min ID. The following account with the specified ID will not be included in the response. NOTE: the ID is of the internal follow, NOT any of the returned accounts.' + in: query + name: min_id + type: string + - default: 40 + description: Number of following accounts to return. + in: query + maximum: 80 + minimum: 1 + name: limit + type: integer produces: - application/json responses: @@ -4679,19 +4731,25 @@ paths: ```` operationId: blocksGet parameters: - - default: 20 - description: Number of blocks to return. - in: query - name: limit - type: integer - - description: Return only blocks *OLDER* than the given block ID. The block with the specified ID will not be included in the response. + - description: 'Return only blocked accounts *OLDER* than the given max ID. The blocked account with the specified ID will not be included in the response. NOTE: the ID is of the internal block, NOT any of the returned accounts.' in: query name: max_id type: string - - description: Return only blocks *NEWER* than the given block ID. The block with the specified ID will not be included in the response. + - description: 'Return only blocked accounts *NEWER* than the given since ID. The blocked account with the specified ID will not be included in the response. NOTE: the ID is of the internal block, NOT any of the returned accounts.' in: query name: since_id type: string + - description: 'Return only blocked accounts *IMMEDIATELY NEWER* than the given min ID. The blocked account with the specified ID will not be included in the response. NOTE: the ID is of the internal block, NOT any of the returned accounts.' + in: query + name: min_id + type: string + - default: 40 + description: Number of blocked accounts to return. + in: query + maximum: 80 + minimum: 1 + name: limit + type: integer produces: - application/json responses: @@ -4857,12 +4915,32 @@ paths: - featured_tags /api/v1/follow_requests: get: - description: Accounts will be sorted in order of follow request date descending (newest first). + description: |- + The next and previous queries can be parsed from the returned Link header. + Example: + + ``` + ; rel="next", ; rel="prev" + ```` operationId: getFollowRequests parameters: - - default: 40 - description: Number of accounts to return. + - description: 'Return only follow requesting accounts *OLDER* than the given max ID. The follow requester with the specified ID will not be included in the response. NOTE: the ID is of the internal follow request, NOT any of the returned accounts.' in: query + name: max_id + type: string + - description: 'Return only follow requesting accounts *NEWER* than the given since ID. The follow requester with the specified ID will not be included in the response. NOTE: the ID is of the internal follow request, NOT any of the returned accounts.' + in: query + name: since_id + type: string + - description: 'Return only follow requesting accounts *IMMEDIATELY NEWER* than the given min ID. The follow requester with the specified ID will not be included in the response. NOTE: the ID is of the internal follow request, NOT any of the returned accounts.' + in: query + name: min_id + type: string + - default: 40 + description: Number of follow requesting accounts to return. + in: query + maximum: 80 + minimum: 1 name: limit type: integer produces: diff --git a/go.mod b/go.mod index 2a6658319..db2a4c3b1 100644 --- a/go.mod +++ b/go.mod @@ -46,6 +46,7 @@ require ( github.com/superseriousbusiness/exif-terminator v0.5.0 github.com/superseriousbusiness/oauth2/v4 v4.3.2-SSB.0.20230227143000-f4900831d6c8 github.com/tdewolff/minify/v2 v2.12.9 + github.com/tomnomnom/linkheader v0.0.0-20180905144013-02ca5825eb80 github.com/ulule/limiter/v3 v3.11.2 github.com/uptrace/bun v1.1.15 github.com/uptrace/bun/dialect/pgdialect v1.1.15 diff --git a/go.sum b/go.sum index 0da102d44..de9eff1ee 100644 --- a/go.sum +++ b/go.sum @@ -568,6 +568,8 @@ github.com/tidwall/tinyqueue v0.0.0-20180302190814-1e39f5511563 h1:Otn9S136ELckZ github.com/tidwall/tinyqueue v0.0.0-20180302190814-1e39f5511563/go.mod h1:mLqSmt7Dv/CNneF2wfcChfN1rvapyQr01LGKnKex0DQ= github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc h1:9lRDQMhESg+zvGYmW5DyG0UqvY96Bu5QYsTLvCHdrgo= github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc/go.mod h1:bciPuU6GHm1iF1pBvUfxfsH0Wmnc2VbpgvbI9ZWuIRs= +github.com/tomnomnom/linkheader v0.0.0-20180905144013-02ca5825eb80 h1:nrZ3ySNYwJbSpD6ce9duiP+QkD3JuLCcWkdaehUS/3Y= +github.com/tomnomnom/linkheader v0.0.0-20180905144013-02ca5825eb80/go.mod h1:iFyPdL66DjUD96XmzVL3ZntbzcflLnznH0fr99w5VqE= github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= github.com/ugorji/go v1.2.7/go.mod h1:nF9osbDWLy6bDVv/Rtoh6QgnvNDpmCalQV5urGCCS6M= diff --git a/internal/api/client/accounts/follow_test.go b/internal/api/client/accounts/follow_test.go index 9660acd4f..47526da1d 100644 --- a/internal/api/client/accounts/follow_test.go +++ b/internal/api/client/accounts/follow_test.go @@ -18,21 +18,33 @@ package accounts_test import ( + "context" + "encoding/json" "fmt" "io/ioutil" + "math/rand" "net/http" "net/http/httptest" + "net/url" + "strconv" "strings" "testing" + "time" "github.com/gin-gonic/gin" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/suite" "github.com/superseriousbusiness/gotosocial/internal/api/client/accounts" + "github.com/superseriousbusiness/gotosocial/internal/api/model" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/testrig" + "github.com/tomnomnom/linkheader" ) +// random reader according to current-time source seed. +var randRd = rand.New(rand.NewSource(time.Now().Unix())) + type FollowTestSuite struct { AccountStandardTestSuite } @@ -69,6 +81,405 @@ func (suite *FollowTestSuite) TestFollowSelf() { assert.NoError(suite.T(), err) } +func (suite *FollowTestSuite) TestGetFollowersPageBackwardLimit2() { + suite.testGetFollowersPage(2, "backward") +} + +func (suite *FollowTestSuite) TestGetFollowersPageBackwardLimit4() { + suite.testGetFollowersPage(4, "backward") +} + +func (suite *FollowTestSuite) TestGetFollowersPageBackwardLimit6() { + suite.testGetFollowersPage(6, "backward") +} + +func (suite *FollowTestSuite) TestGetFollowersPageForwardLimit2() { + suite.testGetFollowersPage(2, "forward") +} + +func (suite *FollowTestSuite) TestGetFollowersPageForwardLimit4() { + suite.testGetFollowersPage(4, "forward") +} + +func (suite *FollowTestSuite) TestGetFollowersPageForwardLimit6() { + suite.testGetFollowersPage(6, "forward") +} + +func (suite *FollowTestSuite) testGetFollowersPage(limit int, direction string) { + ctx := context.Background() + + // The authed local account we are going to use for HTTP requests + requestingAccount := suite.testAccounts["local_account_1"] + suite.clearAccountRelations(requestingAccount.ID) + + // Get current time. + now := time.Now() + + var i int + + for _, targetAccount := range suite.testAccounts { + if targetAccount.ID == requestingAccount.ID { + // we cannot be our own target... + continue + } + + // Get next simple ID. + id := strconv.Itoa(i) + i++ + + // put a follow in the database + err := suite.db.PutFollow(ctx, >smodel.Follow{ + ID: id, + CreatedAt: now, + UpdatedAt: now, + URI: fmt.Sprintf("%s/follow/%s", targetAccount.URI, id), + AccountID: targetAccount.ID, + TargetAccountID: requestingAccount.ID, + }) + suite.NoError(err) + + // Bump now by 1 second. + now = now.Add(time.Second) + } + + // Get _ALL_ follows we expect to see without any paging (this filters invisible). + apiRsp, err := suite.processor.Account().FollowersGet(ctx, requestingAccount, requestingAccount.ID, nil) + suite.NoError(err) + expectAccounts := apiRsp.Items // interfaced{} account slice + + // Iteratively set + // link query string. + var query string + + switch direction { + case "backward": + // Set the starting query to page backward from newest. + acc := expectAccounts[0].(*model.Account) + newest, _ := suite.db.GetFollow(ctx, acc.ID, requestingAccount.ID) + expectAccounts = expectAccounts[1:] + query = fmt.Sprintf("limit=%d&max_id=%s", limit, newest.ID) + + case "forward": + // Set the starting query to page forward from the oldest. + acc := expectAccounts[len(expectAccounts)-1].(*model.Account) + oldest, _ := suite.db.GetFollow(ctx, acc.ID, requestingAccount.ID) + expectAccounts = expectAccounts[:len(expectAccounts)-1] + query = fmt.Sprintf("limit=%d&min_id=%s", limit, oldest.ID) + } + + for p := 0; ; p++ { + // Prepare new request for endpoint + recorder := httptest.NewRecorder() + endpoint := fmt.Sprintf("/api/v1/accounts/%s/followers", requestingAccount.ID) + ctx := suite.newContext(recorder, http.MethodGet, []byte{}, endpoint, "") + ctx.Params = gin.Params{{Key: "id", Value: requestingAccount.ID}} + ctx.Request.URL.RawQuery = query // setting provided next query value + + // call the handler and check for valid response code. + suite.T().Logf("direction=%q page=%d query=%q", direction, p, query) + suite.accountsModule.AccountFollowersGETHandler(ctx) + suite.Equal(http.StatusOK, recorder.Code) + + var accounts []*model.Account + + // Decode response body into API account models + result := recorder.Result() + dec := json.NewDecoder(result.Body) + err := dec.Decode(&accounts) + suite.NoError(err) + _ = result.Body.Close() + + var ( + + // start provides the starting index for loop in accounts. + start func([]*model.Account) int + + // iter performs the loop iter step with index. + iter func(int) int + + // check performs the loop conditional check against index and accounts. + check func(int, []*model.Account) bool + + // expect pulls the next account to check against from expectAccounts. + expect func([]interface{}) interface{} + + // trunc drops the last checked account from expectAccounts. + trunc func([]interface{}) []interface{} + ) + + switch direction { + case "backward": + // When paging backwards (DESC) we: + // - iter from end of received accounts + // - iterate backward through received accounts + // - stop when we reach last index of received accounts + // - compare each received with the first index of expected accounts + // - after each compare, drop the first index of expected accounts + start = func([]*model.Account) int { return 0 } + iter = func(i int) int { return i + 1 } + check = func(idx int, i []*model.Account) bool { return idx < len(i) } + expect = func(i []interface{}) interface{} { return i[0] } + trunc = func(i []interface{}) []interface{} { return i[1:] } + + case "forward": + // When paging forwards (ASC) we: + // - iter from end of received accounts + // - iterate backward through received accounts + // - stop when we reach first index of received accounts + // - compare each received with the last index of expected accounts + // - after each compare, drop the last index of expected accounts + start = func(i []*model.Account) int { return len(i) - 1 } + iter = func(i int) int { return i - 1 } + check = func(idx int, i []*model.Account) bool { return idx >= 0 } + expect = func(i []interface{}) interface{} { return i[len(i)-1] } + trunc = func(i []interface{}) []interface{} { return i[:len(i)-1] } + } + + for i := start(accounts); check(i, accounts); i = iter(i) { + // Get next expected account. + iface := expect(expectAccounts) + + // Check that expected account matches received. + expectAccID := iface.(*model.Account).ID + receivdAccID := accounts[i].ID + suite.Equal(expectAccID, receivdAccID, "unexpected account at position in response on page=%d", p) + + // Drop checked from expected accounts. + expectAccounts = trunc(expectAccounts) + } + + if len(expectAccounts) == 0 { + // Reached end. + break + } + + // Parse response link header values. + values := result.Header.Values("Link") + links := linkheader.ParseMultiple(values) + filteredLinks := links.FilterByRel("next") + suite.NotEmpty(filteredLinks, "no next link provided with more remaining accounts on page=%d", p) + + // A ref link header was set. + link := filteredLinks[0] + + // Parse URI from URI string. + uri, err := url.Parse(link.URL) + suite.NoError(err) + + // Set next raw query value. + query = uri.RawQuery + } +} + +func (suite *FollowTestSuite) TestGetFollowingPageBackwardLimit2() { + suite.testGetFollowingPage(2, "backward") +} + +func (suite *FollowTestSuite) TestGetFollowingPageBackwardLimit4() { + suite.testGetFollowingPage(4, "backward") +} + +func (suite *FollowTestSuite) TestGetFollowingPageBackwardLimit6() { + suite.testGetFollowingPage(6, "backward") +} + +func (suite *FollowTestSuite) TestGetFollowingPageForwardLimit2() { + suite.testGetFollowingPage(2, "forward") +} + +func (suite *FollowTestSuite) TestGetFollowingPageForwardLimit4() { + suite.testGetFollowingPage(4, "forward") +} + +func (suite *FollowTestSuite) TestGetFollowingPageForwardLimit6() { + suite.testGetFollowingPage(6, "forward") +} + +func (suite *FollowTestSuite) testGetFollowingPage(limit int, direction string) { + ctx := context.Background() + + // The authed local account we are going to use for HTTP requests + requestingAccount := suite.testAccounts["local_account_1"] + suite.clearAccountRelations(requestingAccount.ID) + + // Get current time. + now := time.Now() + + var i int + + for _, targetAccount := range suite.testAccounts { + if targetAccount.ID == requestingAccount.ID { + // we cannot be our own target... + continue + } + + // Get next simple ID. + id := strconv.Itoa(i) + i++ + + // put a follow in the database + err := suite.db.PutFollow(ctx, >smodel.Follow{ + ID: id, + CreatedAt: now, + UpdatedAt: now, + URI: fmt.Sprintf("%s/follow/%s", requestingAccount.URI, id), + AccountID: requestingAccount.ID, + TargetAccountID: targetAccount.ID, + }) + suite.NoError(err) + + // Bump now by 1 second. + now = now.Add(time.Second) + } + + // Get _ALL_ follows we expect to see without any paging (this filters invisible). + apiRsp, err := suite.processor.Account().FollowingGet(ctx, requestingAccount, requestingAccount.ID, nil) + suite.NoError(err) + expectAccounts := apiRsp.Items // interfaced{} account slice + + // Iteratively set + // link query string. + var query string + + switch direction { + case "backward": + // Set the starting query to page backward from newest. + acc := expectAccounts[0].(*model.Account) + newest, _ := suite.db.GetFollow(ctx, requestingAccount.ID, acc.ID) + expectAccounts = expectAccounts[1:] + query = fmt.Sprintf("limit=%d&max_id=%s", limit, newest.ID) + + case "forward": + // Set the starting query to page forward from the oldest. + acc := expectAccounts[len(expectAccounts)-1].(*model.Account) + oldest, _ := suite.db.GetFollow(ctx, requestingAccount.ID, acc.ID) + expectAccounts = expectAccounts[:len(expectAccounts)-1] + query = fmt.Sprintf("limit=%d&min_id=%s", limit, oldest.ID) + } + + for p := 0; ; p++ { + // Prepare new request for endpoint + recorder := httptest.NewRecorder() + endpoint := fmt.Sprintf("/api/v1/accounts/%s/following", requestingAccount.ID) + ctx := suite.newContext(recorder, http.MethodGet, []byte{}, endpoint, "") + ctx.Params = gin.Params{{Key: "id", Value: requestingAccount.ID}} + ctx.Request.URL.RawQuery = query // setting provided next query value + + // call the handler and check for valid response code. + suite.T().Logf("direction=%q page=%d query=%q", direction, p, query) + suite.accountsModule.AccountFollowingGETHandler(ctx) + suite.Equal(http.StatusOK, recorder.Code) + + var accounts []*model.Account + + // Decode response body into API account models + result := recorder.Result() + dec := json.NewDecoder(result.Body) + err := dec.Decode(&accounts) + suite.NoError(err) + _ = result.Body.Close() + + var ( + // start provides the starting index for loop in accounts. + start func([]*model.Account) int + + // iter performs the loop iter step with index. + iter func(int) int + + // check performs the loop conditional check against index and accounts. + check func(int, []*model.Account) bool + + // expect pulls the next account to check against from expectAccounts. + expect func([]interface{}) interface{} + + // trunc drops the last checked account from expectAccounts. + trunc func([]interface{}) []interface{} + ) + + switch direction { + case "backward": + // When paging backwards (DESC) we: + // - iter from end of received accounts + // - iterate backward through received accounts + // - stop when we reach last index of received accounts + // - compare each received with the first index of expected accounts + // - after each compare, drop the first index of expected accounts + start = func([]*model.Account) int { return 0 } + iter = func(i int) int { return i + 1 } + check = func(idx int, i []*model.Account) bool { return idx < len(i) } + expect = func(i []interface{}) interface{} { return i[0] } + trunc = func(i []interface{}) []interface{} { return i[1:] } + + case "forward": + // When paging forwards (ASC) we: + // - iter from end of received accounts + // - iterate backward through received accounts + // - stop when we reach first index of received accounts + // - compare each received with the last index of expected accounts + // - after each compare, drop the last index of expected accounts + start = func(i []*model.Account) int { return len(i) - 1 } + iter = func(i int) int { return i - 1 } + check = func(idx int, i []*model.Account) bool { return idx >= 0 } + expect = func(i []interface{}) interface{} { return i[len(i)-1] } + trunc = func(i []interface{}) []interface{} { return i[:len(i)-1] } + } + + for i := start(accounts); check(i, accounts); i = iter(i) { + // Get next expected account. + iface := expect(expectAccounts) + + // Check that expected account matches received. + expectAccID := iface.(*model.Account).ID + receivdAccID := accounts[i].ID + suite.Equal(expectAccID, receivdAccID, "unexpected account at position in response on page=%d", p) + + // Drop checked from expected accounts. + expectAccounts = trunc(expectAccounts) + } + + if len(expectAccounts) == 0 { + // Reached end. + break + } + + // Parse response link header values. + values := result.Header.Values("Link") + links := linkheader.ParseMultiple(values) + filteredLinks := links.FilterByRel("next") + suite.NotEmpty(filteredLinks, "no next link provided with more remaining accounts on page=%d", p) + + // A ref link header was set. + link := filteredLinks[0] + + // Parse URI from URI string. + uri, err := url.Parse(link.URL) + suite.NoError(err) + + // Set next raw query value. + query = uri.RawQuery + } +} + +func (suite *FollowTestSuite) clearAccountRelations(id string) { + // Esnure no account blocks exist between accounts. + _ = suite.db.DeleteAccountBlocks( + context.Background(), + id, + ) + + // Ensure no account follows exist between accounts. + _ = suite.db.DeleteAccountFollows( + context.Background(), + id, + ) + + // Ensure no account follow_requests exist between accounts. + _ = suite.db.DeleteAccountFollowRequests( + context.Background(), + id, + ) +} + func TestFollowTestSuite(t *testing.T) { suite.Run(t, new(FollowTestSuite)) } diff --git a/internal/api/client/accounts/followers.go b/internal/api/client/accounts/followers.go index 96b034877..2448bc50a 100644 --- a/internal/api/client/accounts/followers.go +++ b/internal/api/client/accounts/followers.go @@ -25,12 +25,20 @@ import ( apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util" "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/oauth" + "github.com/superseriousbusiness/gotosocial/internal/paging" ) // AccountFollowersGETHandler swagger:operation GET /api/v1/accounts/{id}/followers accountFollowers // // See followers of account with given id. // +// The next and previous queries can be parsed from the returned Link header. +// Example: +// +// ``` +// ; rel="next", ; rel="prev" +// ```` +// // --- // tags: // - accounts @@ -45,6 +53,42 @@ import ( // description: Account ID. // in: path // required: true +// - +// name: max_id +// type: string +// description: >- +// Return only follower accounts *OLDER* than the given max ID. +// The follower account with the specified ID will not be included in the response. +// NOTE: the ID is of the internal follow, NOT any of the returned accounts. +// in: query +// required: false +// - +// name: since_id +// type: string +// description: >- +// Return only follower accounts *NEWER* than the given since ID. +// The follower account with the specified ID will not be included in the response. +// NOTE: the ID is of the internal follow, NOT any of the returned accounts. +// in: query +// required: false +// - +// name: min_id +// type: string +// description: >- +// Return only follower accounts *IMMEDIATELY NEWER* than the given min ID. +// The follower account with the specified ID will not be included in the response. +// NOTE: the ID is of the internal follow, NOT any of the returned accounts. +// in: query +// required: false +// - +// name: limit +// type: integer +// description: Number of follower accounts to return. +// default: 40 +// minimum: 1 +// maximum: 80 +// in: query +// required: false // // security: // - OAuth2 Bearer: @@ -87,11 +131,25 @@ func (m *Module) AccountFollowersGETHandler(c *gin.Context) { return } - followers, errWithCode := m.processor.Account().FollowersGet(c.Request.Context(), authed.Account, targetAcctID) + page, errWithCode := paging.ParseIDPage(c, + 1, // min limit + 80, // max limit + 40, // default limit + ) if errWithCode != nil { apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) return } - c.JSON(http.StatusOK, followers) + resp, errWithCode := m.processor.Account().FollowersGet(c.Request.Context(), authed.Account, targetAcctID, page) + if errWithCode != nil { + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) + return + } + + if resp.LinkHeader != "" { + c.Header("Link", resp.LinkHeader) + } + + c.JSON(http.StatusOK, resp.Items) } diff --git a/internal/api/client/accounts/following.go b/internal/api/client/accounts/following.go index 122a12a6e..d106d6ea6 100644 --- a/internal/api/client/accounts/following.go +++ b/internal/api/client/accounts/following.go @@ -25,12 +25,20 @@ import ( apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util" "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/oauth" + "github.com/superseriousbusiness/gotosocial/internal/paging" ) // AccountFollowingGETHandler swagger:operation GET /api/v1/accounts/{id}/following accountFollowing // // See accounts followed by given account id. // +// The next and previous queries can be parsed from the returned Link header. +// Example: +// +// ``` +// ; rel="next", ; rel="prev" +// ```` +// // --- // tags: // - accounts @@ -45,6 +53,42 @@ import ( // description: Account ID. // in: path // required: true +// - +// name: max_id +// type: string +// description: >- +// Return only following accounts *OLDER* than the given max ID. +// The following account with the specified ID will not be included in the response. +// NOTE: the ID is of the internal follow, NOT any of the returned accounts. +// in: query +// required: false +// - +// name: since_id +// type: string +// description: >- +// Return only following accounts *NEWER* than the given since ID. +// The following account with the specified ID will not be included in the response. +// NOTE: the ID is of the internal follow, NOT any of the returned accounts. +// in: query +// required: false +// - +// name: min_id +// type: string +// description: >- +// Return only following accounts *IMMEDIATELY NEWER* than the given min ID. +// The following account with the specified ID will not be included in the response. +// NOTE: the ID is of the internal follow, NOT any of the returned accounts. +// in: query +// required: false +// - +// name: limit +// type: integer +// description: Number of following accounts to return. +// default: 40 +// minimum: 1 +// maximum: 80 +// in: query +// required: false // // security: // - OAuth2 Bearer: @@ -87,11 +131,25 @@ func (m *Module) AccountFollowingGETHandler(c *gin.Context) { return } - following, errWithCode := m.processor.Account().FollowingGet(c.Request.Context(), authed.Account, targetAcctID) + page, errWithCode := paging.ParseIDPage(c, + 1, // min limit + 80, // max limit + 40, // default limit + ) if errWithCode != nil { apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) return } - c.JSON(http.StatusOK, following) + resp, errWithCode := m.processor.Account().FollowingGet(c.Request.Context(), authed.Account, targetAcctID, page) + if errWithCode != nil { + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) + return + } + + if resp.LinkHeader != "" { + c.Header("Link", resp.LinkHeader) + } + + c.JSON(http.StatusOK, resp.Items) } diff --git a/internal/api/client/blocks/blocksget.go b/internal/api/client/blocks/blocksget.go index dcf70e9cf..0761160bc 100644 --- a/internal/api/client/blocks/blocksget.go +++ b/internal/api/client/blocks/blocksget.go @@ -47,25 +47,40 @@ import ( // // parameters: // - -// name: limit -// type: integer -// description: Number of blocks to return. -// default: 20 -// in: query -// - // name: max_id // type: string // description: >- -// Return only blocks *OLDER* than the given block ID. -// The block with the specified ID will not be included in the response. +// Return only blocked accounts *OLDER* than the given max ID. +// The blocked account with the specified ID will not be included in the response. +// NOTE: the ID is of the internal block, NOT any of the returned accounts. // in: query +// required: false // - // name: since_id // type: string // description: >- -// Return only blocks *NEWER* than the given block ID. -// The block with the specified ID will not be included in the response. +// Return only blocked accounts *NEWER* than the given since ID. +// The blocked account with the specified ID will not be included in the response. +// NOTE: the ID is of the internal block, NOT any of the returned accounts. // in: query +// - +// name: min_id +// type: string +// description: >- +// Return only blocked accounts *IMMEDIATELY NEWER* than the given min ID. +// The blocked account with the specified ID will not be included in the response. +// NOTE: the ID is of the internal block, NOT any of the returned accounts. +// in: query +// required: false +// - +// name: limit +// type: integer +// description: Number of blocked accounts to return. +// default: 40 +// minimum: 1 +// maximum: 80 +// in: query +// required: false // // security: // - OAuth2 Bearer: @@ -104,16 +119,16 @@ func (m *Module) BlocksGETHandler(c *gin.Context) { } page, errWithCode := paging.ParseIDPage(c, - 1, // min limit - 100, // max limit - 20, // default limit + 1, // min limit + 80, // max limit + 40, // default limit ) if errWithCode != nil { apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) return } - resp, errWithCode := m.processor.BlocksGet( + resp, errWithCode := m.processor.Account().BlocksGet( c.Request.Context(), authed.Account, page, diff --git a/internal/api/client/followrequests/authorize.go b/internal/api/client/followrequests/authorize.go index 7a19c0f86..707d3db26 100644 --- a/internal/api/client/followrequests/authorize.go +++ b/internal/api/client/followrequests/authorize.go @@ -87,7 +87,7 @@ func (m *Module) FollowRequestAuthorizePOSTHandler(c *gin.Context) { return } - relationship, errWithCode := m.processor.FollowRequestAccept(c.Request.Context(), authed, originAccountID) + relationship, errWithCode := m.processor.Account().FollowRequestAccept(c.Request.Context(), authed.Account, originAccountID) if errWithCode != nil { apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) return diff --git a/internal/api/client/followrequests/get.go b/internal/api/client/followrequests/get.go index 628e3b807..af2f3741c 100644 --- a/internal/api/client/followrequests/get.go +++ b/internal/api/client/followrequests/get.go @@ -24,12 +24,19 @@ import ( apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util" "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/oauth" + "github.com/superseriousbusiness/gotosocial/internal/paging" ) // FollowRequestGETHandler swagger:operation GET /api/v1/follow_requests getFollowRequests // // Get an array of accounts that have requested to follow you. -// Accounts will be sorted in order of follow request date descending (newest first). +// +// The next and previous queries can be parsed from the returned Link header. +// Example: +// +// ``` +// ; rel="next", ; rel="prev" +// ```` // // --- // tags: @@ -40,11 +47,41 @@ import ( // // parameters: // - +// name: max_id +// type: string +// description: >- +// Return only follow requesting accounts *OLDER* than the given max ID. +// The follow requester with the specified ID will not be included in the response. +// NOTE: the ID is of the internal follow request, NOT any of the returned accounts. +// in: query +// required: false +// - +// name: since_id +// type: string +// description: >- +// Return only follow requesting accounts *NEWER* than the given since ID. +// The follow requester with the specified ID will not be included in the response. +// NOTE: the ID is of the internal follow request, NOT any of the returned accounts. +// in: query +// required: false +// - +// name: min_id +// type: string +// description: >- +// Return only follow requesting accounts *IMMEDIATELY NEWER* than the given min ID. +// The follow requester with the specified ID will not be included in the response. +// NOTE: the ID is of the internal follow request, NOT any of the returned accounts. +// in: query +// required: false +// - // name: limit // type: integer -// description: Number of accounts to return. +// description: Number of follow requesting accounts to return. // default: 40 +// minimum: 1 +// maximum: 80 // in: query +// required: false // // security: // - OAuth2 Bearer: @@ -82,11 +119,25 @@ func (m *Module) FollowRequestGETHandler(c *gin.Context) { return } - accts, errWithCode := m.processor.FollowRequestsGet(c.Request.Context(), authed) + page, errWithCode := paging.ParseIDPage(c, + 1, // min limit + 80, // max limit + 40, // default limit + ) if errWithCode != nil { apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) return } - c.JSON(http.StatusOK, accts) + resp, errWithCode := m.processor.Account().FollowRequestsGet(c.Request.Context(), authed.Account, page) + if errWithCode != nil { + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) + return + } + + if resp.LinkHeader != "" { + c.Header("Link", resp.LinkHeader) + } + + c.JSON(http.StatusOK, resp.Items) } diff --git a/internal/api/client/followrequests/get_test.go b/internal/api/client/followrequests/get_test.go index d95c9878c..f2fa832a1 100644 --- a/internal/api/client/followrequests/get_test.go +++ b/internal/api/client/followrequests/get_test.go @@ -22,17 +22,25 @@ import ( "context" "encoding/json" "fmt" - "io/ioutil" + "io" + "math/rand" "net/http" "net/http/httptest" + "net/url" + "strconv" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/suite" + "github.com/superseriousbusiness/gotosocial/internal/api/model" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/tomnomnom/linkheader" ) +// random reader according to current-time source seed. +var randRd = rand.New(rand.NewSource(time.Now().Unix())) + type GetTestSuite struct { FollowRequestStandardTestSuite } @@ -68,7 +76,7 @@ func (suite *GetTestSuite) TestGet() { defer result.Body.Close() // check the response - b, err := ioutil.ReadAll(result.Body) + b, err := io.ReadAll(result.Body) assert.NoError(suite.T(), err) dst := new(bytes.Buffer) err = json.Indent(dst, b, "", " ") @@ -99,6 +107,214 @@ func (suite *GetTestSuite) TestGet() { ]`, dst.String()) } +func (suite *GetTestSuite) TestGetPageBackwardLimit2() { + suite.testGetPage(2, "backward") +} + +func (suite *GetTestSuite) TestGetPageBackwardLimit4() { + suite.testGetPage(4, "backward") +} + +func (suite *GetTestSuite) TestGetPageBackwardLimit6() { + suite.testGetPage(6, "backward") +} + +func (suite *GetTestSuite) TestGetPageForwardLimit2() { + suite.testGetPage(2, "forward") +} + +func (suite *GetTestSuite) TestGetPageForwardLimit4() { + suite.testGetPage(4, "forward") +} + +func (suite *GetTestSuite) TestGetPageForwardLimit6() { + suite.testGetPage(6, "forward") +} + +func (suite *GetTestSuite) testGetPage(limit int, direction string) { + ctx := context.Background() + + // The authed local account we are going to use for HTTP requests + requestingAccount := suite.testAccounts["local_account_1"] + suite.clearAccountRelations(requestingAccount.ID) + + // Get current time. + now := time.Now() + + var i int + + for _, targetAccount := range suite.testAccounts { + if targetAccount.ID == requestingAccount.ID { + // we cannot be our own target... + continue + } + + // Get next simple ID. + id := strconv.Itoa(i) + i++ + + // put a follow request in the database + err := suite.db.PutFollowRequest(ctx, >smodel.FollowRequest{ + ID: id, + CreatedAt: now, + UpdatedAt: now, + URI: fmt.Sprintf("%s/follow/%s", targetAccount.URI, id), + AccountID: targetAccount.ID, + TargetAccountID: requestingAccount.ID, + }) + suite.NoError(err) + + // Bump now by 1 second. + now = now.Add(time.Second) + } + + // Get _ALL_ follow requests we expect to see without any paging (this filters invisible). + apiRsp, err := suite.processor.Account().FollowRequestsGet(ctx, requestingAccount, nil) + suite.NoError(err) + expectAccounts := apiRsp.Items // interfaced{} account slice + + // Iteratively set + // link query string. + var query string + + switch direction { + case "backward": + // Set the starting query to page backward from newest. + acc := expectAccounts[0].(*model.Account) + newest, _ := suite.db.GetFollowRequest(ctx, acc.ID, requestingAccount.ID) + expectAccounts = expectAccounts[1:] + query = fmt.Sprintf("limit=%d&max_id=%s", limit, newest.ID) + + case "forward": + // Set the starting query to page forward from the oldest. + acc := expectAccounts[len(expectAccounts)-1].(*model.Account) + oldest, _ := suite.db.GetFollowRequest(ctx, acc.ID, requestingAccount.ID) + expectAccounts = expectAccounts[:len(expectAccounts)-1] + query = fmt.Sprintf("limit=%d&min_id=%s", limit, oldest.ID) + } + + for p := 0; ; p++ { + // Prepare new request for endpoint + recorder := httptest.NewRecorder() + ctx := suite.newContext(recorder, http.MethodGet, []byte{}, "/api/v1/follow_requests", "") + ctx.Request.URL.RawQuery = query // setting provided next query value + + // call the handler and check for valid response code. + suite.T().Logf("direction=%q page=%d query=%q", direction, p, query) + suite.followRequestModule.FollowRequestGETHandler(ctx) + suite.Equal(http.StatusOK, recorder.Code) + + var accounts []*model.Account + + // Decode response body into API account models + result := recorder.Result() + dec := json.NewDecoder(result.Body) + err := dec.Decode(&accounts) + suite.NoError(err) + _ = result.Body.Close() + + var ( + + // start provides the starting index for loop in accounts. + start func([]*model.Account) int + + // iter performs the loop iter step with index. + iter func(int) int + + // check performs the loop conditional check against index and accounts. + check func(int, []*model.Account) bool + + // expect pulls the next account to check against from expectAccounts. + expect func([]interface{}) interface{} + + // trunc drops the last checked account from expectAccounts. + trunc func([]interface{}) []interface{} + ) + + switch direction { + case "backward": + // When paging backwards (DESC) we: + // - iter from end of received accounts + // - iterate backward through received accounts + // - stop when we reach last index of received accounts + // - compare each received with the first index of expected accounts + // - after each compare, drop the first index of expected accounts + start = func([]*model.Account) int { return 0 } + iter = func(i int) int { return i + 1 } + check = func(idx int, i []*model.Account) bool { return idx < len(i) } + expect = func(i []interface{}) interface{} { return i[0] } + trunc = func(i []interface{}) []interface{} { return i[1:] } + + case "forward": + // When paging forwards (ASC) we: + // - iter from end of received accounts + // - iterate backward through received accounts + // - stop when we reach first index of received accounts + // - compare each received with the last index of expected accounts + // - after each compare, drop the last index of expected accounts + start = func(i []*model.Account) int { return len(i) - 1 } + iter = func(i int) int { return i - 1 } + check = func(idx int, i []*model.Account) bool { return idx >= 0 } + expect = func(i []interface{}) interface{} { return i[len(i)-1] } + trunc = func(i []interface{}) []interface{} { return i[:len(i)-1] } + } + + for i := start(accounts); check(i, accounts); i = iter(i) { + // Get next expected account. + iface := expect(expectAccounts) + + // Check that expected account matches received. + expectAccID := iface.(*model.Account).ID + receivdAccID := accounts[i].ID + suite.Equal(expectAccID, receivdAccID, "unexpected account at position in response on page=%d", p) + + // Drop checked from expected accounts. + expectAccounts = trunc(expectAccounts) + } + + if len(expectAccounts) == 0 { + // Reached end. + break + } + + // Parse response link header values. + values := result.Header.Values("Link") + links := linkheader.ParseMultiple(values) + filteredLinks := links.FilterByRel("next") + suite.NotEmpty(filteredLinks, "no next link provided with more remaining accounts on page=%d", p) + + // A ref link header was set. + link := filteredLinks[0] + + // Parse URI from URI string. + uri, err := url.Parse(link.URL) + suite.NoError(err) + + // Set next raw query value. + query = uri.RawQuery + } +} + +func (suite *GetTestSuite) clearAccountRelations(id string) { + // Esnure no account blocks exist between accounts. + _ = suite.db.DeleteAccountBlocks( + context.Background(), + id, + ) + + // Ensure no account follows exist between accounts. + _ = suite.db.DeleteAccountFollows( + context.Background(), + id, + ) + + // Ensure no account follow_requests exist between accounts. + _ = suite.db.DeleteAccountFollowRequests( + context.Background(), + id, + ) +} + func TestGetTestSuite(t *testing.T) { suite.Run(t, &GetTestSuite{}) } diff --git a/internal/api/client/followrequests/reject.go b/internal/api/client/followrequests/reject.go index 3f75facba..6514a615e 100644 --- a/internal/api/client/followrequests/reject.go +++ b/internal/api/client/followrequests/reject.go @@ -85,7 +85,7 @@ func (m *Module) FollowRequestRejectPOSTHandler(c *gin.Context) { return } - relationship, errWithCode := m.processor.FollowRequestReject(c.Request.Context(), authed, originAccountID) + relationship, errWithCode := m.processor.Account().FollowRequestReject(c.Request.Context(), authed.Account, originAccountID) if errWithCode != nil { apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) return diff --git a/internal/db/bundb/relationship.go b/internal/db/bundb/relationship.go index f1bdcf52b..822e697c1 100644 --- a/internal/db/bundb/relationship.go +++ b/internal/db/bundb/relationship.go @@ -102,8 +102,8 @@ func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount return &rel, nil } -func (r *relationshipDB) GetAccountFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) { - followIDs, err := r.getAccountFollowIDs(ctx, accountID) +func (r *relationshipDB) GetAccountFollows(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.Follow, error) { + followIDs, err := r.getAccountFollowIDs(ctx, accountID, page) if err != nil { return nil, err } @@ -118,8 +118,8 @@ func (r *relationshipDB) GetAccountLocalFollows(ctx context.Context, accountID s return r.GetFollowsByIDs(ctx, followIDs) } -func (r *relationshipDB) GetAccountFollowers(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) { - followerIDs, err := r.getAccountFollowerIDs(ctx, accountID) +func (r *relationshipDB) GetAccountFollowers(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.Follow, error) { + followerIDs, err := r.getAccountFollowerIDs(ctx, accountID, page) if err != nil { return nil, err } @@ -134,16 +134,16 @@ func (r *relationshipDB) GetAccountLocalFollowers(ctx context.Context, accountID return r.GetFollowsByIDs(ctx, followerIDs) } -func (r *relationshipDB) GetAccountFollowRequests(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, error) { - followReqIDs, err := r.getAccountFollowRequestIDs(ctx, accountID) +func (r *relationshipDB) GetAccountFollowRequests(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.FollowRequest, error) { + followReqIDs, err := r.getAccountFollowRequestIDs(ctx, accountID, page) if err != nil { return nil, err } return r.GetFollowRequestsByIDs(ctx, followReqIDs) } -func (r *relationshipDB) GetAccountFollowRequesting(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, error) { - followReqIDs, err := r.getAccountFollowRequestingIDs(ctx, accountID) +func (r *relationshipDB) GetAccountFollowRequesting(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.FollowRequest, error) { + followReqIDs, err := r.getAccountFollowRequestingIDs(ctx, accountID, page) if err != nil { return nil, err } @@ -151,39 +151,15 @@ func (r *relationshipDB) GetAccountFollowRequesting(ctx context.Context, account } func (r *relationshipDB) GetAccountBlocks(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.Block, error) { - // Load block IDs from cache with database loader callback. - blockIDs, err := r.state.Caches.GTS.BlockIDs().Load(accountID, func() ([]string, error) { - var blockIDs []string - - // Block IDs not in cache, perform DB query! - q := newSelectBlocks(r.db, accountID) - if _, err := q.Exec(ctx, &blockIDs); err != nil { - return nil, err - } - - return blockIDs, nil - }) + blockIDs, err := r.getAccountBlockIDs(ctx, accountID, page) if err != nil { return nil, err } - - // Our cached / selected block IDs are - // ALWAYS stored in descending order. - // Depending on the paging requested - // this may be an unexpected order. - if !page.GetOrder().Ascending() { - blockIDs = paging.Reverse(blockIDs) - } - - // Page the resulting block IDs. - blockIDs = page.Page(blockIDs) - - // Convert these IDs to full block objects. return r.GetBlocksByIDs(ctx, blockIDs) } func (r *relationshipDB) CountAccountFollows(ctx context.Context, accountID string) (int, error) { - followIDs, err := r.getAccountFollowIDs(ctx, accountID) + followIDs, err := r.getAccountFollowIDs(ctx, accountID, nil) return len(followIDs), err } @@ -193,7 +169,7 @@ func (r *relationshipDB) CountAccountLocalFollows(ctx context.Context, accountID } func (r *relationshipDB) CountAccountFollowers(ctx context.Context, accountID string) (int, error) { - followerIDs, err := r.getAccountFollowerIDs(ctx, accountID) + followerIDs, err := r.getAccountFollowerIDs(ctx, accountID, nil) return len(followerIDs), err } @@ -203,17 +179,22 @@ func (r *relationshipDB) CountAccountLocalFollowers(ctx context.Context, account } func (r *relationshipDB) CountAccountFollowRequests(ctx context.Context, accountID string) (int, error) { - followReqIDs, err := r.getAccountFollowRequestIDs(ctx, accountID) + followReqIDs, err := r.getAccountFollowRequestIDs(ctx, accountID, nil) return len(followReqIDs), err } func (r *relationshipDB) CountAccountFollowRequesting(ctx context.Context, accountID string) (int, error) { - followReqIDs, err := r.getAccountFollowRequestingIDs(ctx, accountID) + followReqIDs, err := r.getAccountFollowRequestingIDs(ctx, accountID, nil) return len(followReqIDs), err } -func (r *relationshipDB) getAccountFollowIDs(ctx context.Context, accountID string) ([]string, error) { - return r.state.Caches.GTS.FollowIDs().Load(">"+accountID, func() ([]string, error) { +func (r *relationshipDB) CountAccountBlocks(ctx context.Context, accountID string) (int, error) { + blockIDs, err := r.getAccountBlockIDs(ctx, accountID, nil) + return len(blockIDs), err +} + +func (r *relationshipDB) getAccountFollowIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) { + return loadPagedIDs(r.state.Caches.GTS.FollowIDs(), ">"+accountID, page, func() ([]string, error) { var followIDs []string // Follow IDs not in cache, perform DB query! @@ -240,8 +221,8 @@ func (r *relationshipDB) getAccountLocalFollowIDs(ctx context.Context, accountID }) } -func (r *relationshipDB) getAccountFollowerIDs(ctx context.Context, accountID string) ([]string, error) { - return r.state.Caches.GTS.FollowIDs().Load("<"+accountID, func() ([]string, error) { +func (r *relationshipDB) getAccountFollowerIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) { + return loadPagedIDs(r.state.Caches.GTS.FollowIDs(), "<"+accountID, page, func() ([]string, error) { var followIDs []string // Follow IDs not in cache, perform DB query! @@ -268,8 +249,8 @@ func (r *relationshipDB) getAccountLocalFollowerIDs(ctx context.Context, account }) } -func (r *relationshipDB) getAccountFollowRequestIDs(ctx context.Context, accountID string) ([]string, error) { - return r.state.Caches.GTS.FollowRequestIDs().Load(">"+accountID, func() ([]string, error) { +func (r *relationshipDB) getAccountFollowRequestIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) { + return loadPagedIDs(r.state.Caches.GTS.FollowRequestIDs(), ">"+accountID, page, func() ([]string, error) { var followReqIDs []string // Follow request IDs not in cache, perform DB query! @@ -282,8 +263,8 @@ func (r *relationshipDB) getAccountFollowRequestIDs(ctx context.Context, account }) } -func (r *relationshipDB) getAccountFollowRequestingIDs(ctx context.Context, accountID string) ([]string, error) { - return r.state.Caches.GTS.FollowRequestIDs().Load("<"+accountID, func() ([]string, error) { +func (r *relationshipDB) getAccountFollowRequestingIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) { + return loadPagedIDs(r.state.Caches.GTS.FollowRequestIDs(), "<"+accountID, page, func() ([]string, error) { var followReqIDs []string // Follow request IDs not in cache, perform DB query! @@ -296,13 +277,27 @@ func (r *relationshipDB) getAccountFollowRequestingIDs(ctx context.Context, acco }) } +func (r *relationshipDB) getAccountBlockIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) { + return loadPagedIDs(r.state.Caches.GTS.BlockIDs(), accountID, page, func() ([]string, error) { + var blockIDs []string + + // Block IDs not in cache, perform DB query! + q := newSelectBlocks(r.db, accountID) + if _, err := q.Exec(ctx, &blockIDs); err != nil { + return nil, err + } + + return blockIDs, nil + }) +} + // newSelectFollowRequests returns a new select query for all rows in the follow_requests table with target_account_id = accountID. func newSelectFollowRequests(db *DB, accountID string) *bun.SelectQuery { return db.NewSelect(). TableExpr("?", bun.Ident("follow_requests")). ColumnExpr("?", bun.Ident("id")). Where("? = ?", bun.Ident("target_account_id"), accountID). - OrderExpr("? DESC", bun.Ident("updated_at")) + OrderExpr("? DESC", bun.Ident("id")) } // newSelectFollowRequesting returns a new select query for all rows in the follow_requests table with account_id = accountID. @@ -311,7 +306,7 @@ func newSelectFollowRequesting(db *DB, accountID string) *bun.SelectQuery { TableExpr("?", bun.Ident("follow_requests")). ColumnExpr("?", bun.Ident("id")). Where("? = ?", bun.Ident("target_account_id"), accountID). - OrderExpr("? DESC", bun.Ident("updated_at")) + OrderExpr("? DESC", bun.Ident("id")) } // newSelectFollows returns a new select query for all rows in the follows table with account_id = accountID. @@ -320,7 +315,7 @@ func newSelectFollows(db *DB, accountID string) *bun.SelectQuery { Table("follows"). Column("id"). Where("? = ?", bun.Ident("account_id"), accountID). - OrderExpr("? DESC", bun.Ident("updated_at")) + OrderExpr("? DESC", bun.Ident("id")) } // newSelectLocalFollows returns a new select query for all rows in the follows table with @@ -338,7 +333,7 @@ func newSelectLocalFollows(db *DB, accountID string) *bun.SelectQuery { Column("id"). Where("? IS NULL", bun.Ident("domain")), ). - OrderExpr("? DESC", bun.Ident("updated_at")) + OrderExpr("? DESC", bun.Ident("id")) } // newSelectFollowers returns a new select query for all rows in the follows table with target_account_id = accountID. @@ -347,7 +342,7 @@ func newSelectFollowers(db *DB, accountID string) *bun.SelectQuery { Table("follows"). Column("id"). Where("? = ?", bun.Ident("target_account_id"), accountID). - OrderExpr("? DESC", bun.Ident("updated_at")) + OrderExpr("? DESC", bun.Ident("id")) } // newSelectLocalFollowers returns a new select query for all rows in the follows table with @@ -365,14 +360,14 @@ func newSelectLocalFollowers(db *DB, accountID string) *bun.SelectQuery { Column("id"). Where("? IS NULL", bun.Ident("domain")), ). - OrderExpr("? DESC", bun.Ident("updated_at")) + OrderExpr("? DESC", bun.Ident("id")) } // newSelectBlocks returns a new select query for all rows in the blocks table with account_id = accountID. func newSelectBlocks(db *DB, accountID string) *bun.SelectQuery { return db.NewSelect(). TableExpr("?", bun.Ident("blocks")). - ColumnExpr("?", bun.Ident("?")). + ColumnExpr("?", bun.Ident("id")). Where("? = ?", bun.Ident("account_id"), accountID). - OrderExpr("? DESC", bun.Ident("updated_at")) + OrderExpr("? DESC", bun.Ident("id")) } diff --git a/internal/db/bundb/relationship_test.go b/internal/db/bundb/relationship_test.go index d7c93ff0e..aa2353961 100644 --- a/internal/db/bundb/relationship_test.go +++ b/internal/db/bundb/relationship_test.go @@ -753,14 +753,14 @@ func (suite *RelationshipTestSuite) TestGetAccountFollowRequests() { suite.FailNow(err.Error()) } - followRequests, err := suite.db.GetAccountFollowRequests(ctx, targetAccount.ID) + followRequests, err := suite.db.GetAccountFollowRequests(ctx, targetAccount.ID, nil) suite.NoError(err) suite.Len(followRequests, 1) } func (suite *RelationshipTestSuite) TestGetAccountFollows() { account := suite.testAccounts["local_account_1"] - follows, err := suite.db.GetAccountFollows(context.Background(), account.ID) + follows, err := suite.db.GetAccountFollows(context.Background(), account.ID, nil) suite.NoError(err) suite.Len(follows, 2) } @@ -781,7 +781,7 @@ func (suite *RelationshipTestSuite) TestCountAccountFollows() { func (suite *RelationshipTestSuite) TestGetAccountFollowers() { account := suite.testAccounts["local_account_1"] - follows, err := suite.db.GetAccountFollowers(context.Background(), account.ID) + follows, err := suite.db.GetAccountFollowers(context.Background(), account.ID, nil) suite.NoError(err) suite.Len(follows, 2) } diff --git a/internal/db/bundb/timeline.go b/internal/db/bundb/timeline.go index f63937bc1..229245899 100644 --- a/internal/db/bundb/timeline.go +++ b/internal/db/bundb/timeline.go @@ -114,6 +114,7 @@ func (t *timelineDB) GetHomeTimeline(ctx context.Context, accountID string, maxI follows, err := t.state.DB.GetAccountFollows( gtscontext.SetBarebones(ctx), accountID, + nil, // select all ) if err != nil && !errors.Is(err, db.ErrNoEntries) { return nil, gtserror.Newf("db error getting follows for account %s: %w", accountID, err) diff --git a/internal/db/bundb/timeline_test.go b/internal/db/bundb/timeline_test.go index e5a78dfd1..ac169ec4a 100644 --- a/internal/db/bundb/timeline_test.go +++ b/internal/db/bundb/timeline_test.go @@ -167,8 +167,8 @@ func (suite *TimelineTestSuite) TestGetHomeTimelineNoFollowing() { follows, err := suite.state.DB.GetAccountFollows( gtscontext.SetBarebones(ctx), viewingAccount.ID, + nil, // select all ) - if err != nil { suite.FailNow(err.Error()) } diff --git a/internal/db/bundb/util.go b/internal/db/bundb/util.go index 3c3249daf..1d820d081 100644 --- a/internal/db/bundb/util.go +++ b/internal/db/bundb/util.go @@ -20,7 +20,9 @@ package bundb import ( "strings" + "github.com/superseriousbusiness/gotosocial/internal/cache" "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/paging" "github.com/uptrace/bun" ) @@ -83,6 +85,29 @@ func whereStartsLike( ) } +// loadPagedIDs loads a page of IDs from given SliceCache by `key`, resorting to `loadDESC` if required. Uses `page` to sort + page resulting IDs. +// NOTE: IDs returned from `cache` / `loadDESC` MUST be in descending order, otherwise paging will not work correctly / return things out of order. +func loadPagedIDs(cache *cache.SliceCache[string], key string, page *paging.Page, loadDESC func() ([]string, error)) ([]string, error) { + // Check cache for IDs, else load. + ids, err := cache.Load(key, loadDESC) + if err != nil { + return nil, err + } + + // Our cached / selected IDs are ALWAYS + // fetched from `loadDESC` in descending + // order. Depending on the paging requested + // this may be an unexpected order. + if page.GetOrder().Ascending() { + ids = paging.Reverse(ids) + } + + // Page the resulting IDs. + ids = page.Page(ids) + + return ids, nil +} + // updateWhere parses []db.Where and adds it to the given update query. func updateWhere(q *bun.UpdateQuery, where []db.Where) { for _, w := range where { diff --git a/internal/db/relationship.go b/internal/db/relationship.go index 91c98644c..b3b45551b 100644 --- a/internal/db/relationship.go +++ b/internal/db/relationship.go @@ -138,43 +138,46 @@ type Relationship interface { RejectFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) error // GetAccountFollows returns a slice of follows owned by the given accountID. - GetAccountFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) + GetAccountFollows(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.Follow, error) // GetAccountLocalFollows returns a slice of follows owned by the given accountID, only including follows from this instance. GetAccountLocalFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) + // GetAccountFollowers fetches follows that target given accountID. + GetAccountFollowers(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.Follow, error) + + // GetAccountLocalFollowers fetches follows that target given accountID, only including follows from this instance. + GetAccountLocalFollowers(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) + + // GetAccountFollowRequests returns all follow requests targeting the given account. + GetAccountFollowRequests(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.FollowRequest, error) + + // GetAccountFollowRequesting returns all follow requests originating from the given account. + GetAccountFollowRequesting(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.FollowRequest, error) + + // GetAccountBlocks returns all blocks originating from the given account, with given optional paging parameters. + GetAccountBlocks(ctx context.Context, accountID string, paging *paging.Page) ([]*gtsmodel.Block, error) + // CountAccountFollows returns the amount of accounts that the given accountID is following. CountAccountFollows(ctx context.Context, accountID string) (int, error) // CountAccountLocalFollows returns the amount of accounts that the given accountID is following, only including follows from this instance. CountAccountLocalFollows(ctx context.Context, accountID string) (int, error) - // GetAccountFollowers fetches follows that target given accountID. - GetAccountFollowers(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) - - // GetAccountLocalFollowers fetches follows that target given accountID, only including follows from this instance. - GetAccountLocalFollowers(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) - // CountAccountFollowers returns the amounts that the given ID is followed by. CountAccountFollowers(ctx context.Context, accountID string) (int, error) // CountAccountLocalFollowers returns the amounts that the given ID is followed by, only including follows from this instance. CountAccountLocalFollowers(ctx context.Context, accountID string) (int, error) - // GetAccountFollowRequests returns all follow requests targeting the given account. - GetAccountFollowRequests(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, error) - - // GetAccountFollowRequesting returns all follow requests originating from the given account. - GetAccountFollowRequesting(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, error) - // CountAccountFollowRequests returns number of follow requests targeting the given account. CountAccountFollowRequests(ctx context.Context, accountID string) (int, error) // CountAccountFollowerRequests returns number of follow requests originating from the given account. CountAccountFollowRequesting(ctx context.Context, accountID string) (int, error) - // GetAccountBlocks returns all blocks originating from the given account, with given optional paging parameters. - GetAccountBlocks(ctx context.Context, accountID string, paging *paging.Page) ([]*gtsmodel.Block, error) + // CountAccountBlocks ... + CountAccountBlocks(ctx context.Context, accountID string) (int, error) // GetNote gets a private note from a source account on a target account, if it exists. GetNote(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.AccountNote, error) diff --git a/internal/federation/federatingdb/followers.go b/internal/federation/federatingdb/followers.go index 4ca2e2683..eada48c1b 100644 --- a/internal/federation/federatingdb/followers.go +++ b/internal/federation/federatingdb/followers.go @@ -38,7 +38,7 @@ func (f *federatingDB) Followers(ctx context.Context, actorIRI *url.URL) (follow return nil, err } - follows, err := f.state.DB.GetAccountFollowers(ctx, acct.ID) + follows, err := f.state.DB.GetAccountFollowers(ctx, acct.ID, nil) if err != nil { return nil, fmt.Errorf("Followers: db error getting followers for account id %s: %s", acct.ID, err) } diff --git a/internal/federation/federatingdb/following.go b/internal/federation/federatingdb/following.go index 391a2f810..deb965564 100644 --- a/internal/federation/federatingdb/following.go +++ b/internal/federation/federatingdb/following.go @@ -38,7 +38,7 @@ func (f *federatingDB) Following(ctx context.Context, actorIRI *url.URL) (follow return nil, err } - follows, err := f.state.DB.GetAccountFollows(ctx, acct.ID) + follows, err := f.state.DB.GetAccountFollows(ctx, acct.ID, nil) if err != nil { return nil, fmt.Errorf("Following: db error getting following for account id %s: %w", acct.ID, err) } diff --git a/internal/federation/federatingdb/following_test.go b/internal/federation/federatingdb/following_test.go index 83d1a72b5..93bc6d348 100644 --- a/internal/federation/federatingdb/following_test.go +++ b/internal/federation/federatingdb/following_test.go @@ -47,8 +47,8 @@ func (suite *FollowingTestSuite) TestGetFollowing() { suite.Equal(`{ "@context": "https://www.w3.org/ns/activitystreams", "items": [ - "http://localhost:8080/users/admin", - "http://localhost:8080/users/1happyturtle" + "http://localhost:8080/users/1happyturtle", + "http://localhost:8080/users/admin" ], "type": "Collection" }`, string(fJson)) diff --git a/internal/federation/federatingdb/inbox.go b/internal/federation/federatingdb/inbox.go index 18974ba79..9bd9f8d87 100644 --- a/internal/federation/federatingdb/inbox.go +++ b/internal/federation/federatingdb/inbox.go @@ -89,7 +89,7 @@ func (f *federatingDB) InboxesForIRI(c context.Context, iri *url.URL) (inboxIRIs return nil, fmt.Errorf("couldn't find local account with username %s: %s", localAccountUsername, err) } - follows, err := f.state.DB.GetAccountFollowers(c, account.ID) + follows, err := f.state.DB.GetAccountFollowers(c, account.ID, nil) if err != nil { return nil, fmt.Errorf("couldn't get followers of local account %s: %s", localAccountUsername, err) } diff --git a/internal/paging/boundary.go b/internal/paging/boundary.go index 2f202097b..15af65e0c 100644 --- a/internal/paging/boundary.go +++ b/internal/paging/boundary.go @@ -17,10 +17,10 @@ package paging -// MinID returns an ID boundary with given min ID value, +// EitherMinID returns an ID boundary with given min ID value, // using either the `since_id`,"DESC" name,ordering or // `min_id`,"ASC" name,ordering depending on which is set. -func MinID(minID, sinceID string) Boundary { +func EitherMinID(minID, sinceID string) Boundary { /* Paging with `since_id` vs `min_id`: @@ -47,18 +47,28 @@ func MinID(minID, sinceID string) Boundary { */ switch { case minID != "": - return Boundary{ - Name: "min_id", - Value: minID, - Order: OrderAscending, - } + return MinID(minID) default: // default min is `since_id` - return Boundary{ - Name: "since_id", - Value: sinceID, - Order: OrderDescending, - } + return SinceID(sinceID) + } +} + +// SinceID ... +func SinceID(sinceID string) Boundary { + return Boundary{ + Name: "since_id", + Value: sinceID, + Order: OrderDescending, + } +} + +// MinID ... +func MinID(minID string) Boundary { + return Boundary{ + Name: "min_id", + Value: minID, + Order: OrderAscending, } } @@ -111,7 +121,7 @@ func (b Boundary) new(value string) Boundary { // Find finds the boundary's set value in input slice, or returns -1. func (b Boundary) Find(in []string) int { - if zero(b.Value) { + if b.Value == "" { return -1 } for i := range in { @@ -121,15 +131,3 @@ func (b Boundary) Find(in []string) int { } return -1 } - -// Query returns this boundary as assembled query key=value pair. -func (b Boundary) Query() string { - switch { - case zero(b.Value): - return "" - case b.Name == "": - panic("value without boundary name") - default: - return b.Name + "=" + b.Value - } -} diff --git a/internal/paging/page.go b/internal/paging/page.go index 7d8f84aab..0a9bc71b1 100644 --- a/internal/paging/page.go +++ b/internal/paging/page.go @@ -20,7 +20,6 @@ package paging import ( "net/url" "strconv" - "strings" "golang.org/x/exp/slices" ) @@ -70,26 +69,10 @@ func (p *Page) GetOrder() Order { } func (p *Page) order() Order { - var ( - // Check if min/max values set. - minValue = zero(p.Min.Value) - maxValue = zero(p.Max.Value) - - // Check if min/max orders set. - minOrder = (p.Min.Order != 0) - maxOrder = (p.Max.Order != 0) - ) - switch { - // Boundaries with a value AND order set - // take priority. Min always comes first. - case minValue && minOrder: + case p.Min.Order != 0: return p.Min.Order - case maxValue && maxOrder: - return p.Max.Order - case minOrder: - return p.Min.Order - case maxOrder: + case p.Max.Order != 0: return p.Max.Order default: return 0 @@ -108,31 +91,9 @@ func (p *Page) Page(in []string) []string { return in } - if o := p.order(); !o.Ascending() { - // Default sort is descending, - // catching all cases when NOT - // ascending (even zero value). - // - // NOTE: sorted data does not always - // occur according to string ineqs - // so we unfortunately cannot check. - - if maxIdx := p.Max.Find(in); maxIdx != -1 { - // Reslice skipping up to max. - in = in[maxIdx+1:] - } - - if minIdx := p.Min.Find(in); minIdx != -1 { - // Reslice stripping past min. - in = in[:minIdx] - } - } else { + if p.order().Ascending() { // Sort type is ascending, input // data is assumed to be ascending. - // - // NOTE: sorted data does not always - // occur according to string ineqs - // so we unfortunately cannot check. if minIdx := p.Min.Find(in); minIdx != -1 { // Reslice skipping up to min. @@ -144,6 +105,11 @@ func (p *Page) Page(in []string) []string { in = in[:maxIdx] } + if p.Limit > 0 && p.Limit < len(in) { + // Reslice input to limit. + in = in[:p.Limit] + } + if len(in) > 1 { // Clone input before // any modifications. @@ -153,11 +119,25 @@ func (p *Page) Page(in []string) []string { // ALWAYS be descending. in = Reverse(in) } - } + } else { + // Default sort is descending, + // catching all cases when NOT + // ascending (even zero value). - if p.Limit > 0 && p.Limit < len(in) { - // Reslice input to limit. - in = in[:p.Limit] + if maxIdx := p.Max.Find(in); maxIdx != -1 { + // Reslice skipping up to max. + in = in[maxIdx+1:] + } + + if minIdx := p.Min.Find(in); minIdx != -1 { + // Reslice stripping past min. + in = in[:minIdx] + } + + if p.Limit > 0 && p.Limit < len(in) { + // Reslice input to limit. + in = in[:p.Limit] + } } return in @@ -165,8 +145,8 @@ func (p *Page) Page(in []string) []string { // Next creates a new instance for the next returnable page, using // given max value. This preserves original limit and max key name. -func (p *Page) Next(max string) *Page { - if p == nil || max == "" { +func (p *Page) Next(lo, hi string) *Page { + if p == nil || lo == "" || hi == "" { // no paging. return nil } @@ -177,16 +157,27 @@ func (p *Page) Next(max string) *Page { // Set original limit. p2.Limit = p.Limit - // Create new from old. - p2.Max = p.Max.new(max) + if p.order().Ascending() { + // When ascending, next page + // needs to start with min at + // the next highest value. + p2.Min = p.Min.new(hi) + p2.Max = p.Max.new("") + } else { + // When descending, next page + // needs to start with max at + // the next lowest value. + p2.Min = p.Min.new("") + p2.Max = p.Max.new(lo) + } return p2 } // Prev creates a new instance for the prev returnable page, using // given min value. This preserves original limit and min key name. -func (p *Page) Prev(min string) *Page { - if p == nil || min == "" { +func (p *Page) Prev(lo, hi string) *Page { + if p == nil || lo == "" || hi == "" { // no paging. return nil } @@ -197,55 +188,56 @@ func (p *Page) Prev(min string) *Page { // Set original limit. p2.Limit = p.Limit - // Create new from old. - p2.Min = p.Min.new(min) + if p.order().Ascending() { + // When ascending, prev page + // needs to start with max at + // the next lowest value. + p2.Min = p.Min.new("") + p2.Max = p.Max.new(lo) + } else { + // When descending, next page + // needs to start with max at + // the next lowest value. + p2.Min = p.Min.new(hi) + p2.Max = p.Max.new("") + } return p2 } // ToLink builds a URL link for given endpoint information and extra query parameters, // appending this Page's minimum / maximum boundaries and available limit (if any). -func (p *Page) ToLink(proto, host, path string, queryParams []string) string { +func (p *Page) ToLink(proto, host, path string, queryParams url.Values) string { if p == nil { // no paging. return "" } - // Check length before - // adding boundary params. - old := len(queryParams) + if queryParams == nil { + // Allocate new query parameters. + queryParams = make(url.Values) + } - if minParam := p.Min.Query(); minParam != "" { + if p.Min.Value != "" { // A page-minimum query parameter is available. - queryParams = append(queryParams, minParam) + queryParams.Add(p.Min.Name, p.Min.Value) } - if maxParam := p.Max.Query(); maxParam != "" { + if p.Max.Value != "" { // A page-maximum query parameter is available. - queryParams = append(queryParams, maxParam) - } - - if len(queryParams) == old { - // No page boundaries. - return "" + queryParams.Add(p.Max.Name, p.Max.Value) } if p.Limit > 0 { - // Build limit key-value query parameter. - param := "limit=" + strconv.Itoa(p.Limit) - - // Append `limit=$value` query parameter. - queryParams = append(queryParams, param) + // A page limit query parameter is available. + queryParams.Add("limit", strconv.Itoa(p.Limit)) } - // Join collected params into query str. - query := strings.Join(queryParams, "&") - // Build URL string. return (&url.URL{ Scheme: proto, Host: host, Path: path, - RawQuery: query, + RawQuery: queryParams.Encode(), }).String() } diff --git a/internal/paging/page_test.go b/internal/paging/page_test.go index 419b9ea44..01cc74d9f 100644 --- a/internal/paging/page_test.go +++ b/internal/paging/page_test.go @@ -97,7 +97,7 @@ var cases = []Case{ // Return page and expected IDs. return ids, &paging.Page{ - Min: paging.MinID(minID, ""), + Min: paging.MinID(minID), Max: paging.MaxID(maxID), }, expect }), @@ -129,7 +129,7 @@ var cases = []Case{ // Return page and expected IDs. return ids, &paging.Page{ - Min: paging.MinID(minID, ""), + Min: paging.MinID(minID), Max: paging.MaxID(maxID), Limit: limit, }, expect @@ -156,7 +156,7 @@ var cases = []Case{ // Return page and expected IDs. return ids, &paging.Page{ - Min: paging.MinID(minID, ""), + Min: paging.MinID(minID), Max: paging.MaxID(maxID), Limit: len(ids) * 2, }, expect @@ -182,7 +182,7 @@ var cases = []Case{ // Return page and expected IDs. return ids, &paging.Page{ - Min: paging.MinID("", sinceID), + Min: paging.SinceID(sinceID), Max: paging.MaxID(maxID), }, expect }), @@ -225,7 +225,7 @@ var cases = []Case{ // Return page and expected IDs. return ids, &paging.Page{ - Min: paging.MinID("", sinceID), + Min: paging.SinceID(sinceID), }, expect }), CreateCase("minID set", func(ids []string) ([]string, *paging.Page, []string) { @@ -247,7 +247,7 @@ var cases = []Case{ // Return page and expected IDs. return ids, &paging.Page{ - Min: paging.MinID(minID, ""), + Min: paging.MinID(minID), }, expect }), } diff --git a/internal/paging/parse.go b/internal/paging/parse.go index 55ebef7f5..ce6391708 100644 --- a/internal/paging/parse.go +++ b/internal/paging/parse.go @@ -30,9 +30,9 @@ import ( // While conversely, a zero default limit will not enforce paging, returning a nil page value. func ParseIDPage(c *gin.Context, min, max, _default int) (*Page, gtserror.WithCode) { // Extract request query params. - sinceID := c.Query("since_id") - minID := c.Query("min_id") - maxID := c.Query("max_id") + sinceID, haveSince := c.GetQuery("since_id") + minID, haveMin := c.GetQuery("min_id") + maxID, haveMax := c.GetQuery("max_id") // Extract request limit parameter. limit, errWithCode := ParseLimit(c, min, max, _default) @@ -40,20 +40,38 @@ func ParseIDPage(c *gin.Context, min, max, _default int) (*Page, gtserror.WithCo return nil, errWithCode } - if sinceID == "" && - minID == "" && - maxID == "" && - limit == 0 { + switch { + case haveMin: + // A min_id was supplied, even if the value + // itself is empty. This indicates ASC order. + return &Page{ + Min: MinID(minID), + Max: MaxID(maxID), + Limit: limit, + }, nil + + case haveMax || haveSince: + // A max_id or since_id was supplied, even if the + // value itself is empty. This indicates DESC order. + return &Page{ + Min: SinceID(sinceID), + Max: MaxID(maxID), + Limit: limit, + }, nil + + case limit == 0: // No ID paging params provided, and no default // limit value which indicates paging not enforced. return nil, nil - } - return &Page{ - Min: MinID(minID, sinceID), - Max: MaxID(maxID), - Limit: limit, - }, nil + default: + // only limit. + return &Page{ + Min: SinceID(""), + Max: MaxID(""), + Limit: limit, + }, nil + } } // ParseShortcodeDomainPage parses an emoji shortcode domain Page from a request context, returning BadRequest @@ -62,8 +80,8 @@ func ParseIDPage(c *gin.Context, min, max, _default int) (*Page, gtserror.WithCo // a zero default limit will not enforce paging, returning a nil page value. func ParseShortcodeDomainPage(c *gin.Context, min, max, _default int) (*Page, gtserror.WithCode) { // Extract request query parameters. - minShortcode := c.Query("min_shortcode_domain") - maxShortcode := c.Query("max_shortcode_domain") + minShortcode, haveMin := c.GetQuery("min_shortcode_domain") + maxShortcode, haveMax := c.GetQuery("max_shortcode_domain") // Extract request limit parameter. limit, errWithCode := ParseLimit(c, min, max, _default) @@ -71,8 +89,8 @@ func ParseShortcodeDomainPage(c *gin.Context, min, max, _default int) (*Page, gt return nil, errWithCode } - if minShortcode == "" && - maxShortcode == "" && + if !haveMin && + !haveMax && limit == 0 { // No ID paging params provided, and no default // limit value which indicates paging not enforced. @@ -89,7 +107,10 @@ func ParseShortcodeDomainPage(c *gin.Context, min, max, _default int) (*Page, gt // ParseLimit parses the limit query parameter from a request context, returning BadRequest on error parsing and _default if zero limit given. func ParseLimit(c *gin.Context, min, max, _default int) (int, gtserror.WithCode) { // Get limit query param. - str := c.Query("limit") + str, ok := c.GetQuery("limit") + if !ok { + return _default, nil + } // Attempt to parse limit int. i, err := strconv.Atoi(str) diff --git a/internal/paging/response.go b/internal/paging/response.go index 498b42d34..71b0cf213 100644 --- a/internal/paging/response.go +++ b/internal/paging/response.go @@ -18,6 +18,7 @@ package paging import ( + "net/url" "strings" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" @@ -35,18 +36,13 @@ type ResponseParams struct { Path string // path to use for next/prev queries in the link header Next *Page // page details for the next page Prev *Page // page details for the previous page - Query []string // any extra query parameters to provide in the link header, should be in the format 'example=value' + Query url.Values // any extra query parameters to provide in the link header, should be in the format 'example=value' } // PackageResponse is a convenience function for returning // a bunch of pageable items (notifications, statuses, etc), as well // as a Link header to inform callers of where to find next/prev items. func PackageResponse(params ResponseParams) *apimodel.PageableResponse { - if len(params.Items) == 0 { - // No items to page through. - return EmptyResponse() - } - var ( // Extract paging params. nextPg = params.Next diff --git a/internal/paging/response_test.go b/internal/paging/response_test.go index 8eca2a601..b4b7d6058 100644 --- a/internal/paging/response_test.go +++ b/internal/paging/response_test.go @@ -42,9 +42,9 @@ func (suite *PagingSuite) TestPagingStandard() { resp := paging.PackageResponse(params) suite.Equal(make([]interface{}, 10, 10), resp.Items) - suite.Equal(`; rel="next", ; rel="prev"`, resp.LinkHeader) - suite.Equal(`https://example.org/api/v1/accounts/01H11KA68PM4NNYJEG0FJQ90R3/statuses?max_id=01H11KA1DM2VH3747YDE7FV5HN&limit=10`, resp.NextLink) - suite.Equal(`https://example.org/api/v1/accounts/01H11KA68PM4NNYJEG0FJQ90R3/statuses?min_id=01H11KBBVRRDYYC5KEPME1NP5R&limit=10`, resp.PrevLink) + suite.Equal(`; rel="next", ; rel="prev"`, resp.LinkHeader) + suite.Equal(`https://example.org/api/v1/accounts/01H11KA68PM4NNYJEG0FJQ90R3/statuses?limit=10&max_id=01H11KA1DM2VH3747YDE7FV5HN`, resp.NextLink) + suite.Equal(`https://example.org/api/v1/accounts/01H11KA68PM4NNYJEG0FJQ90R3/statuses?limit=10&min_id=01H11KBBVRRDYYC5KEPME1NP5R`, resp.PrevLink) } func (suite *PagingSuite) TestPagingNoLimit() { @@ -77,9 +77,9 @@ func (suite *PagingSuite) TestPagingNoNextID() { resp := paging.PackageResponse(params) suite.Equal(make([]interface{}, 10, 10), resp.Items) - suite.Equal(`; rel="prev"`, resp.LinkHeader) + suite.Equal(`; rel="prev"`, resp.LinkHeader) suite.Equal(``, resp.NextLink) - suite.Equal(`https://example.org/api/v1/accounts/01H11KA68PM4NNYJEG0FJQ90R3/statuses?min_id=01H11KBBVRRDYYC5KEPME1NP5R&limit=10`, resp.PrevLink) + suite.Equal(`https://example.org/api/v1/accounts/01H11KA68PM4NNYJEG0FJQ90R3/statuses?limit=10&min_id=01H11KBBVRRDYYC5KEPME1NP5R`, resp.PrevLink) } func (suite *PagingSuite) TestPagingNoPrevID() { @@ -94,27 +94,11 @@ func (suite *PagingSuite) TestPagingNoPrevID() { resp := paging.PackageResponse(params) suite.Equal(make([]interface{}, 10, 10), resp.Items) - suite.Equal(`; rel="next"`, resp.LinkHeader) - suite.Equal(`https://example.org/api/v1/accounts/01H11KA68PM4NNYJEG0FJQ90R3/statuses?max_id=01H11KA1DM2VH3747YDE7FV5HN&limit=10`, resp.NextLink) + suite.Equal(`; rel="next"`, resp.LinkHeader) + suite.Equal(`https://example.org/api/v1/accounts/01H11KA68PM4NNYJEG0FJQ90R3/statuses?limit=10&max_id=01H11KA1DM2VH3747YDE7FV5HN`, resp.NextLink) suite.Equal(``, resp.PrevLink) } -func (suite *PagingSuite) TestPagingNoItems() { - config.SetHost("example.org") - - params := paging.ResponseParams{ - Next: nextPage("01H11KA1DM2VH3747YDE7FV5HN", 10), - Prev: prevPage("01H11KBBVRRDYYC5KEPME1NP5R", 10), - } - - resp := paging.PackageResponse(params) - - suite.Empty(resp.Items) - suite.Empty(resp.LinkHeader) - suite.Empty(resp.NextLink) - suite.Empty(resp.PrevLink) -} - func TestPagingSuite(t *testing.T) { suite.Run(t, &PagingSuite{}) } @@ -128,7 +112,7 @@ func nextPage(id string, limit int) *paging.Page { func prevPage(id string, limit int) *paging.Page { return &paging.Page{ - Min: paging.MinID(id, ""), + Min: paging.MinID(id), Limit: limit, } } diff --git a/internal/paging/util.go b/internal/paging/util.go index d9adb9cbf..dd941dd88 100644 --- a/internal/paging/util.go +++ b/internal/paging/util.go @@ -41,9 +41,3 @@ func Reverse(in []string) []string { return in } - -// zero is a shorthand to check a generic value is its zero value. -func zero[T comparable](t T) bool { - var z T - return t == z -} diff --git a/internal/processing/account/account.go b/internal/processing/account/account.go index 7bef8b0c5..a32a73ac1 100644 --- a/internal/processing/account/account.go +++ b/internal/processing/account/account.go @@ -22,6 +22,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/media" "github.com/superseriousbusiness/gotosocial/internal/oauth" + "github.com/superseriousbusiness/gotosocial/internal/processing/common" "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/text" "github.com/superseriousbusiness/gotosocial/internal/typeutils" @@ -32,6 +33,9 @@ import ( // // It also contains logic for actions towards accounts such as following, blocking, seeing follows, etc. type Processor struct { + // common processor logic + c *common.Processor + state *state.State tc typeutils.TypeConverter mediaManager *media.Manager @@ -44,6 +48,7 @@ type Processor struct { // New returns a new account processor. func New( + common *common.Processor, state *state.State, tc typeutils.TypeConverter, mediaManager *media.Manager, @@ -53,6 +58,7 @@ func New( parseMention gtsmodel.ParseMentionFunc, ) Processor { return Processor{ + c: common, state: state, tc: tc, mediaManager: mediaManager, diff --git a/internal/processing/account/account_test.go b/internal/processing/account/account_test.go index 4ba7de16e..2e4a64844 100644 --- a/internal/processing/account/account_test.go +++ b/internal/processing/account/account_test.go @@ -30,6 +30,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/processing" "github.com/superseriousbusiness/gotosocial/internal/processing/account" + "github.com/superseriousbusiness/gotosocial/internal/processing/common" "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/storage" "github.com/superseriousbusiness/gotosocial/internal/transport" @@ -113,7 +114,8 @@ func (suite *AccountStandardTestSuite) SetupTest() { suite.emailSender = testrig.NewEmailSender("../../../web/template/", suite.sentEmails) filter := visibility.NewFilter(&suite.state) - suite.accountProcessor = account.New(&suite.state, suite.tc, suite.mediaManager, suite.oauthServer, suite.federator, filter, processing.GetParseMentionFunc(suite.db, suite.federator)) + common := common.New(&suite.state, suite.tc, suite.federator, filter) + suite.accountProcessor = account.New(&common, &suite.state, suite.tc, suite.mediaManager, suite.oauthServer, suite.federator, filter, processing.GetParseMentionFunc(suite.db, suite.federator)) testrig.StandardDBSetup(suite.db, nil) testrig.StandardStorageSetup(suite.storage, "../../../testrig/media") } diff --git a/internal/processing/account/block.go b/internal/processing/account/block.go index 1ec31a753..270048100 100644 --- a/internal/processing/account/block.go +++ b/internal/processing/account/block.go @@ -28,8 +28,11 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/id" + "github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/messages" + "github.com/superseriousbusiness/gotosocial/internal/paging" "github.com/superseriousbusiness/gotosocial/internal/uris" + "github.com/superseriousbusiness/gotosocial/internal/util" ) // BlockCreate handles the creation of a block from requestingAccount to targetAccountID, either remote or local. @@ -128,6 +131,53 @@ func (p *Processor) BlockRemove(ctx context.Context, requestingAccount *gtsmodel return p.RelationshipGet(ctx, requestingAccount, targetAccountID) } +// BlocksGet ... +func (p *Processor) BlocksGet( + ctx context.Context, + requestingAccount *gtsmodel.Account, + page *paging.Page, +) (*apimodel.PageableResponse, gtserror.WithCode) { + blocks, err := p.state.DB.GetAccountBlocks(ctx, + requestingAccount.ID, + page, + ) + if err != nil && !errors.Is(err, db.ErrNoEntries) { + return nil, gtserror.NewErrorInternalError(err) + } + + // Check for empty response. + count := len(blocks) + if len(blocks) == 0 { + return util.EmptyPageableResponse(), nil + } + + items := make([]interface{}, 0, count) + + for _, block := range blocks { + // Convert target account to frontend API model. (target will never be nil) + account, err := p.tc.AccountToAPIAccountBlocked(ctx, block.TargetAccount) + if err != nil { + log.Errorf(ctx, "error converting account to public api account: %v", err) + continue + } + + // Append target to return items. + items = append(items, account) + } + + // Get the lowest and highest + // ID values, used for paging. + lo := blocks[count-1].ID + hi := blocks[0].ID + + return paging.PackageResponse(paging.ResponseParams{ + Items: items, + Path: "/api/v1/blocks", + Next: page.Next(lo, hi), + Prev: page.Prev(lo, hi), + }), nil +} + func (p *Processor) getBlockTarget(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) (*gtsmodel.Account, *gtsmodel.Block, gtserror.WithCode) { // Account should not block or unblock itself. if requestingAccount.ID == targetAccountID { diff --git a/internal/processing/account/delete.go b/internal/processing/account/delete.go index da13eb20e..e89ebf13f 100644 --- a/internal/processing/account/delete.go +++ b/internal/processing/account/delete.go @@ -160,7 +160,7 @@ func (p *Processor) deleteUserAndTokensForAccount(ctx context.Context, account * // - Follow requests created by account. func (p *Processor) deleteAccountFollows(ctx context.Context, account *gtsmodel.Account) error { // Delete follows targeting this account. - followedBy, err := p.state.DB.GetAccountFollowers(ctx, account.ID) + followedBy, err := p.state.DB.GetAccountFollowers(ctx, account.ID, nil) if err != nil && !errors.Is(err, db.ErrNoEntries) { return gtserror.Newf("db error getting follows targeting account %s: %w", account.ID, err) } @@ -172,7 +172,7 @@ func (p *Processor) deleteAccountFollows(ctx context.Context, account *gtsmodel. } // Delete follow requests targeting this account. - followRequestedBy, err := p.state.DB.GetAccountFollowRequests(ctx, account.ID) + followRequestedBy, err := p.state.DB.GetAccountFollowRequests(ctx, account.ID, nil) if err != nil && !errors.Is(err, db.ErrNoEntries) { return gtserror.Newf("db error getting follow requests targeting account %s: %w", account.ID, err) } @@ -193,7 +193,7 @@ func (p *Processor) deleteAccountFollows(ctx context.Context, account *gtsmodel. ) // Delete follows originating from this account. - following, err := p.state.DB.GetAccountFollows(ctx, account.ID) + following, err := p.state.DB.GetAccountFollows(ctx, account.ID, nil) if err != nil && !errors.Is(err, db.ErrNoEntries) { return gtserror.Newf("db error getting follows owned by account %s: %w", account.ID, err) } @@ -211,7 +211,7 @@ func (p *Processor) deleteAccountFollows(ctx context.Context, account *gtsmodel. } // Delete follow requests originating from this account. - followRequesting, err := p.state.DB.GetAccountFollowRequesting(ctx, account.ID) + followRequesting, err := p.state.DB.GetAccountFollowRequesting(ctx, account.ID, nil) if err != nil && !errors.Is(err, db.ErrNoEntries) { return gtserror.Newf("db error getting follow requests owned by account %s: %w", account.ID, err) } diff --git a/internal/processing/account/follow.go b/internal/processing/account/follow.go index 1aed92e75..8006f8d79 100644 --- a/internal/processing/account/follow.go +++ b/internal/processing/account/follow.go @@ -20,7 +20,6 @@ package account import ( "context" "errors" - "fmt" "github.com/superseriousbusiness/gotosocial/internal/ap" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" @@ -35,7 +34,7 @@ import ( // FollowCreate handles a follow request to an account, either remote or local. func (p *Processor) FollowCreate(ctx context.Context, requestingAccount *gtsmodel.Account, form *apimodel.AccountFollowRequest) (*apimodel.Relationship, gtserror.WithCode) { - targetAccount, errWithCode := p.getFollowTarget(ctx, requestingAccount.ID, form.ID) + targetAccount, errWithCode := p.getFollowTarget(ctx, requestingAccount, form.ID) if errWithCode != nil { return nil, errWithCode } @@ -46,7 +45,7 @@ func (p *Processor) FollowCreate(ctx context.Context, requestingAccount *gtsmode requestingAccount.ID, targetAccount.ID, ); err != nil && !errors.Is(err, db.ErrNoEntries) { - err = fmt.Errorf("FollowCreate: db error checking existing follow: %w", err) + err = gtserror.Newf("db error checking existing follow: %w", err) return nil, gtserror.NewErrorInternalError(err) } else if follow != nil { // Already follows, update if necessary + return relationship. @@ -66,7 +65,7 @@ func (p *Processor) FollowCreate(ctx context.Context, requestingAccount *gtsmode requestingAccount.ID, targetAccount.ID, ); err != nil && !errors.Is(err, db.ErrNoEntries) { - err = fmt.Errorf("FollowCreate: db error checking existing follow request: %w", err) + err = gtserror.Newf("db error checking existing follow request: %w", err) return nil, gtserror.NewErrorInternalError(err) } else if followRequest != nil { // Already requested, update if necessary + return relationship. @@ -100,7 +99,7 @@ func (p *Processor) FollowCreate(ctx context.Context, requestingAccount *gtsmode } if err := p.state.DB.PutFollowRequest(ctx, fr); err != nil { - err = fmt.Errorf("FollowCreate: error creating follow request in db: %s", err) + err = gtserror.Newf("error creating follow request in db: %s", err) return nil, gtserror.NewErrorInternalError(err) } @@ -112,7 +111,7 @@ func (p *Processor) FollowCreate(ctx context.Context, requestingAccount *gtsmode // Because we know the requestingAccount is also // local, we don't need to federate the accept out. if _, err := p.state.DB.AcceptFollowRequest(ctx, requestingAccount.ID, form.ID); err != nil { - err = fmt.Errorf("FollowCreate: error accepting follow request for local unlocked account: %w", err) + err = gtserror.Newf("error accepting follow request for local unlocked account: %w", err) return nil, gtserror.NewErrorInternalError(err) } } else if targetAccount.IsRemote() { @@ -132,7 +131,7 @@ func (p *Processor) FollowCreate(ctx context.Context, requestingAccount *gtsmode // FollowRemove handles the removal of a follow/follow request to an account, either remote or local. func (p *Processor) FollowRemove(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) { - targetAccount, errWithCode := p.getFollowTarget(ctx, requestingAccount.ID, targetAccountID) + targetAccount, errWithCode := p.getFollowTarget(ctx, requestingAccount, targetAccountID) if errWithCode != nil { return nil, errWithCode } @@ -140,7 +139,7 @@ func (p *Processor) FollowRemove(ctx context.Context, requestingAccount *gtsmode // Unfollow and deal with side effects. msgs, err := p.unfollow(ctx, requestingAccount, targetAccount) if err != nil { - return nil, gtserror.NewErrorNotFound(fmt.Errorf("FollowRemove: account %s not found in the db: %s", targetAccountID, err)) + return nil, gtserror.NewErrorNotFound(gtserror.Newf("account %s not found in the db: %s", targetAccountID, err)) } // Batch queue accreted client api messages. @@ -166,7 +165,6 @@ func (p *Processor) updateFollow( currentNotify *bool, update func(...string) error, ) (*apimodel.Relationship, gtserror.WithCode) { - if form.Reblogs == nil && form.Notify == nil { // There's nothing to update. return p.RelationshipGet(ctx, requestingAccount, form.ID) @@ -192,7 +190,7 @@ func (p *Processor) updateFollow( } if err := update(columns...); err != nil { - err = fmt.Errorf("updateFollow: error updating existing follow (request): %w", err) + err = gtserror.Newf("error updating existing follow (request): %w", err) return nil, gtserror.NewErrorInternalError(err) } @@ -201,38 +199,23 @@ func (p *Processor) updateFollow( // getFollowTarget is a convenience function which: // - Checks if account is trying to follow/unfollow itself. -// - Returns not found if there's a block in place between accounts. +// - Returns not found if target should not be visible to requester. // - Returns target account according to its id. -func (p *Processor) getFollowTarget(ctx context.Context, requestingAccountID string, targetAccountID string) (*gtsmodel.Account, gtserror.WithCode) { +func (p *Processor) getFollowTarget(ctx context.Context, requester *gtsmodel.Account, targetID string) (*gtsmodel.Account, gtserror.WithCode) { + // Check for requester. + if requester == nil { + err := errors.New("no authorized user") + return nil, gtserror.NewErrorUnauthorized(err) + } + // Account can't follow or unfollow itself. - if requestingAccountID == targetAccountID { + if requester.ID == targetID { err := errors.New("account can't follow or unfollow itself") return nil, gtserror.NewErrorNotAcceptable(err) } - // Do nothing if a block exists in either direction between accounts. - if blocked, err := p.state.DB.IsEitherBlocked(ctx, requestingAccountID, targetAccountID); err != nil { - err = fmt.Errorf("db error checking block between accounts: %w", err) - return nil, gtserror.NewErrorInternalError(err) - } else if blocked { - err = errors.New("block exists between accounts") - return nil, gtserror.NewErrorNotFound(err) - } - - // Ensure target account retrievable. - targetAccount, err := p.state.DB.GetAccountByID(ctx, targetAccountID) - if err != nil { - if !errors.Is(err, db.ErrNoEntries) { - // Real db error. - err = fmt.Errorf("db error looking for target account %s: %w", targetAccountID, err) - return nil, gtserror.NewErrorInternalError(err) - } - // Account not found. - err = fmt.Errorf("target account %s not found in the db", targetAccountID) - return nil, gtserror.NewErrorNotFound(err, err.Error()) - } - - return targetAccount, nil + // Fetch the target account for requesting user account. + return p.c.GetVisibleTargetAccount(ctx, requester, targetID) } // unfollow is a convenience function for having requesting account @@ -248,7 +231,7 @@ func (p *Processor) unfollow(ctx context.Context, requestingAccount *gtsmodel.Ac // Get follow from requesting account to target account. follow, err := p.state.DB.GetFollow(ctx, requestingAccount.ID, targetAccount.ID) if err != nil && !errors.Is(err, db.ErrNoEntries) { - err = fmt.Errorf("unfollow: error getting follow from %s targeting %s: %w", requestingAccount.ID, targetAccount.ID, err) + err = gtserror.Newf("error getting follow from %s targeting %s: %w", requestingAccount.ID, targetAccount.ID, err) return nil, err } @@ -257,7 +240,7 @@ func (p *Processor) unfollow(ctx context.Context, requestingAccount *gtsmodel.Ac err = p.state.DB.DeleteFollowByID(ctx, follow.ID) if err != nil { if !errors.Is(err, db.ErrNoEntries) { - err = fmt.Errorf("unfollow: error deleting request from %s targeting %s: %w", requestingAccount.ID, targetAccount.ID, err) + err = gtserror.Newf("error deleting request from %s targeting %s: %w", requestingAccount.ID, targetAccount.ID, err) return nil, err } @@ -284,7 +267,7 @@ func (p *Processor) unfollow(ctx context.Context, requestingAccount *gtsmodel.Ac // Get follow request from requesting account to target account. followReq, err := p.state.DB.GetFollowRequest(ctx, requestingAccount.ID, targetAccount.ID) if err != nil && !errors.Is(err, db.ErrNoEntries) { - err = fmt.Errorf("unfollow: error getting follow request from %s targeting %s: %w", requestingAccount.ID, targetAccount.ID, err) + err = gtserror.Newf("error getting follow request from %s targeting %s: %w", requestingAccount.ID, targetAccount.ID, err) return nil, err } @@ -293,7 +276,7 @@ func (p *Processor) unfollow(ctx context.Context, requestingAccount *gtsmodel.Ac err = p.state.DB.DeleteFollowRequestByID(ctx, followReq.ID) if err != nil { if !errors.Is(err, db.ErrNoEntries) { - err = fmt.Errorf("unfollow: error deleting follow request from %s targeting %s: %w", requestingAccount.ID, targetAccount.ID, err) + err = gtserror.Newf("error deleting follow request from %s targeting %s: %w", requestingAccount.ID, targetAccount.ID, err) return nil, err } diff --git a/internal/processing/account/follow_request.go b/internal/processing/account/follow_request.go new file mode 100644 index 000000000..c054637c8 --- /dev/null +++ b/internal/processing/account/follow_request.go @@ -0,0 +1,119 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package account + +import ( + "context" + "errors" + + "github.com/superseriousbusiness/gotosocial/internal/ap" + apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/messages" + "github.com/superseriousbusiness/gotosocial/internal/paging" +) + +// FollowRequestAccept handles the accepting of a follow request from the sourceAccountID to the requestingAccount (the currently authorized account). +func (p *Processor) FollowRequestAccept(ctx context.Context, requestingAccount *gtsmodel.Account, sourceAccountID string) (*apimodel.Relationship, gtserror.WithCode) { + follow, err := p.state.DB.AcceptFollowRequest(ctx, sourceAccountID, requestingAccount.ID) + if err != nil { + return nil, gtserror.NewErrorNotFound(err) + } + + if follow.Account != nil { + // Only enqueue work in the case we have a request creating account stored. + // NOTE: due to how AcceptFollowRequest works, the inverse shouldn't be possible. + p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{ + APObjectType: ap.ActivityFollow, + APActivityType: ap.ActivityAccept, + GTSModel: follow, + OriginAccount: follow.Account, + TargetAccount: follow.TargetAccount, + }) + } + + return p.RelationshipGet(ctx, requestingAccount, sourceAccountID) +} + +// FollowRequestReject handles the rejection of a follow request from the sourceAccountID to the requestingAccount (the currently authorized account). +func (p *Processor) FollowRequestReject(ctx context.Context, requestingAccount *gtsmodel.Account, sourceAccountID string) (*apimodel.Relationship, gtserror.WithCode) { + followRequest, err := p.state.DB.GetFollowRequest(ctx, sourceAccountID, requestingAccount.ID) + if err != nil { + return nil, gtserror.NewErrorNotFound(err) + } + + err = p.state.DB.RejectFollowRequest(ctx, sourceAccountID, requestingAccount.ID) + if err != nil { + return nil, gtserror.NewErrorNotFound(err) + } + + if followRequest.Account != nil { + // Only enqueue work in the case we have a request creating account stored. + // NOTE: due to how GetFollowRequest works, the inverse shouldn't be possible. + p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{ + APObjectType: ap.ActivityFollow, + APActivityType: ap.ActivityReject, + GTSModel: followRequest, + OriginAccount: followRequest.Account, + TargetAccount: followRequest.TargetAccount, + }) + } + + return p.RelationshipGet(ctx, requestingAccount, sourceAccountID) +} + +// FollowRequestsGet fetches a list of the accounts that are follow requesting the given requestingAccount (the currently authorized account). +func (p *Processor) FollowRequestsGet(ctx context.Context, requestingAccount *gtsmodel.Account, page *paging.Page) (*apimodel.PageableResponse, gtserror.WithCode) { + // Fetch follow requests targeting the given requesting account model. + followRequests, err := p.state.DB.GetAccountFollowRequests(ctx, requestingAccount.ID, page) + if err != nil && !errors.Is(err, db.ErrNoEntries) { + return nil, gtserror.NewErrorInternalError(err) + } + + // Check for empty response. + count := len(followRequests) + if count == 0 { + return paging.EmptyResponse(), nil + } + + // Get the lowest and highest + // ID values, used for paging. + lo := followRequests[count-1].ID + hi := followRequests[0].ID + + // Func to fetch follow source at index. + getIdx := func(i int) *gtsmodel.Account { + return followRequests[i].Account + } + + // Get a filtered slice of public API account models. + items := p.c.GetVisibleAPIAccountsPaged(ctx, + requestingAccount, + getIdx, + count, + ) + + return paging.PackageResponse(paging.ResponseParams{ + Items: items, + Path: "/api/v1/follow_requests", + Next: page.Next(lo, hi), + Prev: page.Prev(lo, hi), + }), nil +} diff --git a/internal/processing/account/relationships.go b/internal/processing/account/relationships.go index d12d989ef..58c98f3ba 100644 --- a/internal/processing/account/relationships.go +++ b/internal/processing/account/relationships.go @@ -20,128 +20,120 @@ package account import ( "context" "errors" - "fmt" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" - "github.com/superseriousbusiness/gotosocial/internal/log" + "github.com/superseriousbusiness/gotosocial/internal/paging" ) // FollowersGet fetches a list of the target account's followers. -func (p *Processor) FollowersGet(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) ([]apimodel.Account, gtserror.WithCode) { - if blocked, err := p.state.DB.IsEitherBlocked(ctx, requestingAccount.ID, targetAccountID); err != nil { - err = fmt.Errorf("FollowersGet: db error checking block: %w", err) +func (p *Processor) FollowersGet(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string, page *paging.Page) (*apimodel.PageableResponse, gtserror.WithCode) { + // Fetch target account to check it exists, and visibility of requester->target. + _, errWithCode := p.c.GetVisibleTargetAccount(ctx, requestingAccount, targetAccountID) + if errWithCode != nil { + return nil, errWithCode + } + + follows, err := p.state.DB.GetAccountFollowers(ctx, targetAccountID, page) + if err != nil && !errors.Is(err, db.ErrNoEntries) { + err = gtserror.Newf("db error getting followers: %w", err) return nil, gtserror.NewErrorInternalError(err) - } else if blocked { - err = errors.New("FollowersGet: block exists between accounts") - return nil, gtserror.NewErrorNotFound(err) } - follows, err := p.state.DB.GetAccountFollowers(ctx, targetAccountID) - if err != nil { - if !errors.Is(err, db.ErrNoEntries) { - err = fmt.Errorf("FollowersGet: db error getting followers: %w", err) - return nil, gtserror.NewErrorInternalError(err) - } - return []apimodel.Account{}, nil + // Check for empty response. + count := len(follows) + if count == 0 { + return paging.EmptyResponse(), nil } - return p.accountsFromFollows(ctx, follows, requestingAccount.ID) + // Get the lowest and highest + // ID values, used for paging. + lo := follows[count-1].ID + hi := follows[0].ID + + // Func to fetch follow source at index. + getIdx := func(i int) *gtsmodel.Account { + return follows[i].Account + } + + // Get a filtered slice of public API account models. + items := p.c.GetVisibleAPIAccountsPaged(ctx, + requestingAccount, + getIdx, + len(follows), + ) + + return paging.PackageResponse(paging.ResponseParams{ + Items: items, + Path: "/api/v1/accounts/" + targetAccountID + "/followers", + Next: page.Next(lo, hi), + Prev: page.Prev(lo, hi), + }), nil } // FollowingGet fetches a list of the accounts that target account is following. -func (p *Processor) FollowingGet(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) ([]apimodel.Account, gtserror.WithCode) { - if blocked, err := p.state.DB.IsEitherBlocked(ctx, requestingAccount.ID, targetAccountID); err != nil { - err = fmt.Errorf("FollowingGet: db error checking block: %w", err) +func (p *Processor) FollowingGet(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string, page *paging.Page) (*apimodel.PageableResponse, gtserror.WithCode) { + // Fetch target account to check it exists, and visibility of requester->target. + _, errWithCode := p.c.GetVisibleTargetAccount(ctx, requestingAccount, targetAccountID) + if errWithCode != nil { + return nil, errWithCode + } + + // Fetch known accounts that follow given target account ID. + follows, err := p.state.DB.GetAccountFollows(ctx, targetAccountID, page) + if err != nil && !errors.Is(err, db.ErrNoEntries) { + err = gtserror.Newf("db error getting followers: %w", err) return nil, gtserror.NewErrorInternalError(err) - } else if blocked { - err = errors.New("FollowingGet: block exists between accounts") - return nil, gtserror.NewErrorNotFound(err) } - follows, err := p.state.DB.GetAccountFollows(ctx, targetAccountID) - if err != nil { - if !errors.Is(err, db.ErrNoEntries) { - err = fmt.Errorf("FollowingGet: db error getting followers: %w", err) - return nil, gtserror.NewErrorInternalError(err) - } - return []apimodel.Account{}, nil + // Check for empty response. + count := len(follows) + if count == 0 { + return paging.EmptyResponse(), nil } - return p.targetAccountsFromFollows(ctx, follows, requestingAccount.ID) + // Get the lowest and highest + // ID values, used for paging. + lo := follows[count-1].ID + hi := follows[0].ID + + // Func to fetch follow source at index. + getIdx := func(i int) *gtsmodel.Account { + return follows[i].TargetAccount + } + + // Get a filtered slice of public API account models. + items := p.c.GetVisibleAPIAccountsPaged(ctx, + requestingAccount, + getIdx, + len(follows), + ) + + return paging.PackageResponse(paging.ResponseParams{ + Items: items, + Path: "/api/v1/accounts/" + targetAccountID + "/following", + Next: page.Next(lo, hi), + Prev: page.Prev(lo, hi), + }), nil } // RelationshipGet returns a relationship model describing the relationship of the targetAccount to the Authed account. func (p *Processor) RelationshipGet(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) { if requestingAccount == nil { - return nil, gtserror.NewErrorForbidden(errors.New("not authed")) + return nil, gtserror.NewErrorForbidden(gtserror.New("not authed")) } gtsR, err := p.state.DB.GetRelationship(ctx, requestingAccount.ID, targetAccountID) if err != nil { - return nil, gtserror.NewErrorInternalError(fmt.Errorf("error getting relationship: %s", err)) + return nil, gtserror.NewErrorInternalError(gtserror.Newf("error getting relationship: %s", err)) } r, err := p.tc.RelationshipToAPIRelationship(ctx, gtsR) if err != nil { - return nil, gtserror.NewErrorInternalError(fmt.Errorf("error converting relationship: %s", err)) + return nil, gtserror.NewErrorInternalError(gtserror.Newf("error converting relationship: %s", err)) } return r, nil } - -func (p *Processor) accountsFromFollows(ctx context.Context, follows []*gtsmodel.Follow, requestingAccountID string) ([]apimodel.Account, gtserror.WithCode) { - accounts := make([]apimodel.Account, 0, len(follows)) - for _, follow := range follows { - if follow.Account == nil { - // No account set for some reason; just skip. - log.WithContext(ctx).WithField("follow", follow).Warn("follow had no associated account") - continue - } - - if blocked, err := p.state.DB.IsEitherBlocked(ctx, requestingAccountID, follow.AccountID); err != nil { - err = fmt.Errorf("accountsFromFollows: db error checking block: %w", err) - return nil, gtserror.NewErrorInternalError(err) - } else if blocked { - continue - } - - account, err := p.tc.AccountToAPIAccountPublic(ctx, follow.Account) - if err != nil { - err = fmt.Errorf("accountsFromFollows: error converting account to api account: %w", err) - return nil, gtserror.NewErrorInternalError(err) - } - - accounts = append(accounts, *account) - } - return accounts, nil -} - -func (p *Processor) targetAccountsFromFollows(ctx context.Context, follows []*gtsmodel.Follow, requestingAccountID string) ([]apimodel.Account, gtserror.WithCode) { - accounts := make([]apimodel.Account, 0, len(follows)) - for _, follow := range follows { - if follow.TargetAccount == nil { - // No account set for some reason; just skip. - log.WithContext(ctx).WithField("follow", follow).Warn("follow had no associated target account") - continue - } - - if blocked, err := p.state.DB.IsEitherBlocked(ctx, requestingAccountID, follow.TargetAccountID); err != nil { - err = fmt.Errorf("targetAccountsFromFollows: db error checking block: %w", err) - return nil, gtserror.NewErrorInternalError(err) - } else if blocked { - continue - } - - account, err := p.tc.AccountToAPIAccountPublic(ctx, follow.TargetAccount) - if err != nil { - err = fmt.Errorf("targetAccountsFromFollows: error converting account to api account: %w", err) - return nil, gtserror.NewErrorInternalError(err) - } - - accounts = append(accounts, *account) - } - return accounts, nil -} diff --git a/internal/processing/blocks.go b/internal/processing/blocks.go deleted file mode 100644 index 014b6af21..000000000 --- a/internal/processing/blocks.go +++ /dev/null @@ -1,86 +0,0 @@ -// GoToSocial -// Copyright (C) GoToSocial Authors admin@gotosocial.org -// SPDX-License-Identifier: AGPL-3.0-or-later -// -// This program is free software: you can redistribute it and/or modify -// it under the terms of the GNU Affero General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// This program is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Affero General Public License for more details. -// -// You should have received a copy of the GNU Affero General Public License -// along with this program. If not, see . - -package processing - -import ( - "context" - "errors" - - apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" - "github.com/superseriousbusiness/gotosocial/internal/db" - "github.com/superseriousbusiness/gotosocial/internal/gtserror" - "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" - "github.com/superseriousbusiness/gotosocial/internal/log" - "github.com/superseriousbusiness/gotosocial/internal/paging" - "github.com/superseriousbusiness/gotosocial/internal/util" -) - -// BlocksGet ... -func (p *Processor) BlocksGet( - ctx context.Context, - requestingAccount *gtsmodel.Account, - page *paging.Page, -) (*apimodel.PageableResponse, gtserror.WithCode) { - blocks, err := p.state.DB.GetAccountBlocks(ctx, - requestingAccount.ID, - page, - ) - if err != nil && !errors.Is(err, db.ErrNoEntries) { - return nil, gtserror.NewErrorInternalError(err) - } - - // Check for zero length. - count := len(blocks) - if len(blocks) == 0 { - return util.EmptyPageableResponse(), nil - } - - var ( - items = make([]interface{}, 0, count) - - // Set next + prev values before API converting - // so the caller can still page even on error. - nextMaxIDValue = blocks[count-1].ID - prevMinIDValue = blocks[0].ID - ) - - for _, block := range blocks { - if block.TargetAccount == nil { - // All models should be populated at this point. - log.Warnf(ctx, "block target account was nil: %v", err) - continue - } - - // Convert target account to frontend API model. - account, err := p.tc.AccountToAPIAccountBlocked(ctx, block.TargetAccount) - if err != nil { - log.Errorf(ctx, "error converting account to public api account: %v", err) - continue - } - - // Append target to return items. - items = append(items, account) - } - - return paging.PackageResponse(paging.ResponseParams{ - Items: items, - Path: "/api/v1/blocks", - Next: page.Next(nextMaxIDValue), - Prev: page.Prev(prevMinIDValue), - }), nil -} diff --git a/internal/processing/common/account.go.go b/internal/processing/common/account.go.go new file mode 100644 index 000000000..06e87fa0e --- /dev/null +++ b/internal/processing/common/account.go.go @@ -0,0 +1,238 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package common + +import ( + "context" + "errors" + + apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/log" +) + +// GetTargetAccountBy fetches the target account with db load function, given the authorized (or, nil) requester's +// account. This returns an approprate gtserror.WithCode accounting (ha) for not found and visibility to requester. +func (p *Processor) GetTargetAccountBy( + ctx context.Context, + requester *gtsmodel.Account, + getTargetFromDB func() (*gtsmodel.Account, error), +) ( + account *gtsmodel.Account, + visible bool, + errWithCode gtserror.WithCode, +) { + // Fetch the target account from db. + target, err := getTargetFromDB() + if err != nil && !errors.Is(err, db.ErrNoEntries) { + return nil, false, gtserror.NewErrorInternalError(err) + } + + if target == nil { + // DB loader could not find account in database. + err := errors.New("target account not found") + return nil, false, gtserror.NewErrorNotFound(err) + } + + // Check whether target account is visible to requesting account. + visible, err = p.filter.AccountVisible(ctx, requester, target) + if err != nil { + return nil, false, gtserror.NewErrorInternalError(err) + } + + if requester != nil && visible { + // Ensure the account is up-to-date. + p.federator.RefreshAccountAsync(ctx, + requester.Username, + target, + nil, + false, + ) + } + + return target, visible, nil +} + +// GetTargetAccountByID is a call-through to GetTargetAccountBy() using the db GetAccountByID() function. +func (p *Processor) GetTargetAccountByID( + ctx context.Context, + requester *gtsmodel.Account, + targetID string, +) ( + account *gtsmodel.Account, + visible bool, + errWithCode gtserror.WithCode, +) { + return p.GetTargetAccountBy(ctx, requester, func() (*gtsmodel.Account, error) { + return p.state.DB.GetAccountByID(ctx, targetID) + }) +} + +// GetVisibleTargetAccount calls GetTargetAccountByID(), +// but converts a non-visible result to not-found error. +func (p *Processor) GetVisibleTargetAccount( + ctx context.Context, + requester *gtsmodel.Account, + targetID string, +) ( + account *gtsmodel.Account, + errWithCode gtserror.WithCode, +) { + // Fetch the target account by ID from the database. + target, visible, errWithCode := p.GetTargetAccountByID(ctx, + requester, + targetID, + ) + if errWithCode != nil { + return nil, errWithCode + } + + if !visible { + // Pretend account doesn't exist if not visible. + err := errors.New("target account not found") + return nil, gtserror.NewErrorNotFound(err) + } + + return target, nil +} + +// GetAPIAccount fetches the appropriate API account model depending on whether requester = target. +func (p *Processor) GetAPIAccount( + ctx context.Context, + requester *gtsmodel.Account, + target *gtsmodel.Account, +) ( + apiAcc *apimodel.Account, + errWithCode gtserror.WithCode, +) { + var err error + + if requester != nil && requester.ID == target.ID { + // Only return sensitive account model _if_ requester = target. + apiAcc, err = p.converter.AccountToAPIAccountSensitive(ctx, target) + } else { + // Else, fall back to returning the public account model. + apiAcc, err = p.converter.AccountToAPIAccountPublic(ctx, target) + } + + if err != nil { + err := gtserror.Newf("error converting account: %w", err) + return nil, gtserror.NewErrorInternalError(err) + } + + return apiAcc, nil +} + +// GetAPIAccountBlocked fetches the limited "blocked" account model for given target. +func (p *Processor) GetAPIAccountBlocked( + ctx context.Context, + targetAcc *gtsmodel.Account, +) ( + apiAcc *apimodel.Account, + errWithCode gtserror.WithCode, +) { + apiAccount, err := p.converter.AccountToAPIAccountBlocked(ctx, targetAcc) + if err != nil { + err = gtserror.Newf("error converting account: %w", err) + return nil, gtserror.NewErrorInternalError(err) + } + return apiAccount, nil +} + +// GetVisibleAPIAccounts converts an array of gtsmodel.Accounts (inputted by next function) into +// public API model accounts, checking first for visibility. Please note that all errors will be +// logged at ERROR level, but will not be returned. Callers are likely to run into show-stopping +// errors in the lead-up to this function, whereas calling this should not be a show-stopper. +func (p *Processor) GetVisibleAPIAccounts( + ctx context.Context, + requester *gtsmodel.Account, + next func(int) *gtsmodel.Account, + length int, +) []*apimodel.Account { + return p.getVisibleAPIAccounts(ctx, 3, requester, next, length) +} + +// GetVisibleAPIAccountsPaged is functionally equivalent to GetVisibleAPIAccounts(), +// except the accounts are returned as a converted slice of accounts as interface{}. +func (p *Processor) GetVisibleAPIAccountsPaged( + ctx context.Context, + requester *gtsmodel.Account, + next func(int) *gtsmodel.Account, + length int, +) []interface{} { + accounts := p.getVisibleAPIAccounts(ctx, 3, requester, next, length) + if len(accounts) == 0 { + return nil + } + items := make([]interface{}, len(accounts)) + for i, account := range accounts { + items[i] = account + } + return items +} + +func (p *Processor) getVisibleAPIAccounts( + ctx context.Context, + calldepth int, // used to skip wrapping func above these's names + requester *gtsmodel.Account, + next func(int) *gtsmodel.Account, + length int, +) []*apimodel.Account { + // Start new log entry with + // the above calling func's name. + l := log. + WithContext(ctx). + WithField("caller", log.Caller(calldepth+1)) + + // Preallocate slice according to expected length. + accounts := make([]*apimodel.Account, 0, length) + + for i := 0; i < length; i++ { + // Get next account. + account := next(i) + if account == nil { + continue + } + + // Check whether this account is visible to requesting account. + visible, err := p.filter.AccountVisible(ctx, requester, account) + if err != nil { + l.Errorf("error checking account visibility: %v", err) + continue + } + + if !visible { + // Not visible to requester. + continue + } + + // Convert the account to a public API model representation. + apiAcc, err := p.converter.AccountToAPIAccountPublic(ctx, account) + if err != nil { + l.Errorf("error converting account: %v", err) + continue + } + + // Append API model to return slice. + accounts = append(accounts, apiAcc) + } + + return accounts +} diff --git a/internal/processing/common/common.go b/internal/processing/common/common.go new file mode 100644 index 000000000..53c298579 --- /dev/null +++ b/internal/processing/common/common.go @@ -0,0 +1,50 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package common + +import ( + "github.com/superseriousbusiness/gotosocial/internal/federation" + "github.com/superseriousbusiness/gotosocial/internal/state" + "github.com/superseriousbusiness/gotosocial/internal/typeutils" + "github.com/superseriousbusiness/gotosocial/internal/visibility" +) + +// Processor provides a processor with logic +// common to multiple logical domains of the +// processing subsection of the codebase. +type Processor struct { + state *state.State + converter typeutils.TypeConverter + federator federation.Federator + filter *visibility.Filter +} + +// New returns a new Processor instance. +func New( + state *state.State, + converter typeutils.TypeConverter, + federator federation.Federator, + filter *visibility.Filter, +) Processor { + return Processor{ + state: state, + converter: converter, + federator: federator, + filter: filter, + } +} diff --git a/internal/processing/common/status.go b/internal/processing/common/status.go new file mode 100644 index 000000000..fb480ec7e --- /dev/null +++ b/internal/processing/common/status.go @@ -0,0 +1,248 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package common + +import ( + "context" + "errors" + + apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/log" +) + +// GetTargetStatusBy fetches the target status with db load function, given the authorized (or, nil) requester's +// account. This returns an approprate gtserror.WithCode accounting for not found and visibility to requester. +func (p *Processor) GetTargetStatusBy( + ctx context.Context, + requester *gtsmodel.Account, + getTargetFromDB func() (*gtsmodel.Status, error), +) ( + status *gtsmodel.Status, + visible bool, + errWithCode gtserror.WithCode, +) { + // Fetch the target status from db. + target, err := getTargetFromDB() + if err != nil && !errors.Is(err, db.ErrNoEntries) { + return nil, false, gtserror.NewErrorInternalError(err) + } + + if target == nil { + // DB loader could not find status in database. + err := errors.New("target status not found") + return nil, false, gtserror.NewErrorNotFound(err) + } + + // Check whether target status is visible to requesting account. + visible, err = p.filter.StatusVisible(ctx, requester, target) + if err != nil { + return nil, false, gtserror.NewErrorInternalError(err) + } + + if requester != nil && visible { + // Ensure remote status is up-to-date. + p.federator.RefreshStatusAsync(ctx, + requester.Username, + target, + nil, + false, + ) + } + + return target, visible, nil +} + +// GetTargetStatusByID is a call-through to GetTargetStatus() using the db GetStatusByID() function. +func (p *Processor) GetTargetStatusByID( + ctx context.Context, + requester *gtsmodel.Account, + targetID string, +) ( + status *gtsmodel.Status, + visible bool, + errWithCode gtserror.WithCode, +) { + return p.GetTargetStatusBy(ctx, requester, func() (*gtsmodel.Status, error) { + return p.state.DB.GetStatusByID(ctx, targetID) + }) +} + +// GetVisibleTargetStatus calls GetTargetStatusByID(), +// but converts a non-visible result to not-found error. +func (p *Processor) GetVisibleTargetStatus( + ctx context.Context, + requester *gtsmodel.Account, + targetID string, +) ( + status *gtsmodel.Status, + errWithCode gtserror.WithCode, +) { + // Fetch the target status by ID from the database. + target, visible, errWithCode := p.GetTargetStatusByID(ctx, + requester, + targetID, + ) + if errWithCode != nil { + return nil, errWithCode + } + + if !visible { + // Target should not be seen by requester. + err := errors.New("target status not found") + return nil, gtserror.NewErrorNotFound(err) + } + + return target, nil +} + +// GetAPIStatus fetches the appropriate API status model for target. +func (p *Processor) GetAPIStatus( + ctx context.Context, + requester *gtsmodel.Account, + target *gtsmodel.Status, +) ( + apiStatus *apimodel.Status, + errWithCode gtserror.WithCode, +) { + apiStatus, err := p.converter.StatusToAPIStatus(ctx, target, requester) + if err != nil { + err = gtserror.Newf("error converting status: %w", err) + return nil, gtserror.NewErrorInternalError(err) + } + return apiStatus, nil +} + +// GetVisibleAPIStatuses converts an array of gtsmodel.Status (inputted by next function) into +// API model statuses, checking first for visibility. Please note that all errors will be +// logged at ERROR level, but will not be returned. Callers are likely to run into show-stopping +// errors in the lead-up to this function, whereas calling this should not be a show-stopper. +func (p *Processor) GetVisibleAPIStatuses( + ctx context.Context, + requester *gtsmodel.Account, + next func(int) *gtsmodel.Status, + length int, +) []*apimodel.Status { + return p.getVisibleAPIStatuses(ctx, 3, requester, next, length) +} + +// GetVisibleAPIStatusesPaged is functionally equivalent to GetVisibleAPIStatuses(), +// except the statuses are returned as a converted slice of statuses as interface{}. +func (p *Processor) GetVisibleAPIStatusesPaged( + ctx context.Context, + requester *gtsmodel.Account, + next func(int) *gtsmodel.Status, + length int, +) []interface{} { + statuses := p.getVisibleAPIStatuses(ctx, 3, requester, next, length) + if len(statuses) == 0 { + return nil + } + items := make([]interface{}, len(statuses)) + for i, status := range statuses { + items[i] = status + } + return items +} + +func (p *Processor) getVisibleAPIStatuses( + ctx context.Context, + calldepth int, // used to skip wrapping func above these's names + requester *gtsmodel.Account, + next func(int) *gtsmodel.Status, + length int, +) []*apimodel.Status { + // Start new log entry with + // the above calling func's name. + l := log. + WithContext(ctx). + WithField("caller", log.Caller(calldepth+1)) + + // Preallocate slice according to expected length. + statuses := make([]*apimodel.Status, 0, length) + + for i := 0; i < length; i++ { + // Get next status. + status := next(i) + if status == nil { + continue + } + + // Check whether this status is visible to requesting account. + visible, err := p.filter.StatusVisible(ctx, requester, status) + if err != nil { + l.Errorf("error checking status visibility: %v", err) + continue + } + + if !visible { + // Not visible to requester. + continue + } + + // Convert the status to an API model representation. + apiStatus, err := p.converter.StatusToAPIStatus(ctx, status, requester) + if err != nil { + l.Errorf("error converting status: %v", err) + continue + } + + // Append API model to return slice. + statuses = append(statuses, apiStatus) + } + + return statuses +} + +// InvalidateTimelinedStatus is a shortcut function for invalidating the cached +// representation one status in the home timeline and all list timelines of the +// given accountID. It should only be called in cases where a status update +// does *not* need to be passed into the processor via the worker queue, since +// such invalidation will, in that case, be handled by the processor instead. +func (p *Processor) InvalidateTimelinedStatus(ctx context.Context, accountID string, statusID string) error { + // Get lists first + bail if this fails. + lists, err := p.state.DB.GetListsForAccountID(ctx, accountID) + if err != nil { + return gtserror.Newf("db error getting lists for account %s: %w", accountID, err) + } + + // Start new log entry with + // the above calling func's name. + l := log. + WithContext(ctx). + WithField("caller", log.Caller(3)). + WithField("accountID", accountID). + WithField("statusID", statusID) + + // Unprepare item from home + list timelines, just log + // if something goes wrong since this is not a showstopper. + + if err := p.state.Timelines.Home.UnprepareItem(ctx, accountID, statusID); err != nil { + l.Errorf("error unpreparing item from home timeline: %v", err) + } + + for _, list := range lists { + if err := p.state.Timelines.List.UnprepareItem(ctx, list.ID, statusID); err != nil { + l.Errorf("error unpreparing item from list timeline %s: %v", list.ID, err) + } + } + + return nil +} diff --git a/internal/processing/followrequest.go b/internal/processing/followrequest.go deleted file mode 100644 index 6587b73bb..000000000 --- a/internal/processing/followrequest.go +++ /dev/null @@ -1,123 +0,0 @@ -// GoToSocial -// Copyright (C) GoToSocial Authors admin@gotosocial.org -// SPDX-License-Identifier: AGPL-3.0-or-later -// -// This program is free software: you can redistribute it and/or modify -// it under the terms of the GNU Affero General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// This program is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Affero General Public License for more details. -// -// You should have received a copy of the GNU Affero General Public License -// along with this program. If not, see . - -package processing - -import ( - "context" - "errors" - - "github.com/superseriousbusiness/gotosocial/internal/ap" - apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" - "github.com/superseriousbusiness/gotosocial/internal/db" - "github.com/superseriousbusiness/gotosocial/internal/gtserror" - "github.com/superseriousbusiness/gotosocial/internal/log" - "github.com/superseriousbusiness/gotosocial/internal/messages" - "github.com/superseriousbusiness/gotosocial/internal/oauth" -) - -func (p *Processor) FollowRequestsGet(ctx context.Context, auth *oauth.Auth) ([]apimodel.Account, gtserror.WithCode) { - followRequests, err := p.state.DB.GetAccountFollowRequests(ctx, auth.Account.ID) - if err != nil && !errors.Is(err, db.ErrNoEntries) { - return nil, gtserror.NewErrorInternalError(err) - } - - accts := make([]apimodel.Account, 0, len(followRequests)) - for _, followRequest := range followRequests { - if followRequest.Account == nil { - // The creator of the follow doesn't exist, - // just skip this one. - log.WithContext(ctx).WithField("followRequest", followRequest).Warn("follow request had no associated account") - continue - } - - apiAcct, err := p.tc.AccountToAPIAccountPublic(ctx, followRequest.Account) - if err != nil { - return nil, gtserror.NewErrorInternalError(err) - } - - accts = append(accts, *apiAcct) - } - - return accts, nil -} - -func (p *Processor) FollowRequestAccept(ctx context.Context, auth *oauth.Auth, accountID string) (*apimodel.Relationship, gtserror.WithCode) { - follow, err := p.state.DB.AcceptFollowRequest(ctx, accountID, auth.Account.ID) - if err != nil { - return nil, gtserror.NewErrorNotFound(err) - } - - if follow.Account == nil { - // The creator of the follow doesn't exist, - // so we can't do further processing. - log.WithContext(ctx).WithField("follow", follow).Warn("follow had no associated account") - return p.relationship(ctx, auth.Account.ID, accountID) - } - - p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{ - APObjectType: ap.ActivityFollow, - APActivityType: ap.ActivityAccept, - GTSModel: follow, - OriginAccount: follow.Account, - TargetAccount: follow.TargetAccount, - }) - - return p.relationship(ctx, auth.Account.ID, accountID) -} - -func (p *Processor) FollowRequestReject(ctx context.Context, auth *oauth.Auth, accountID string) (*apimodel.Relationship, gtserror.WithCode) { - followRequest, err := p.state.DB.GetFollowRequest(ctx, accountID, auth.Account.ID) - if err != nil { - return nil, gtserror.NewErrorNotFound(err) - } - - err = p.state.DB.RejectFollowRequest(ctx, accountID, auth.Account.ID) - if err != nil { - return nil, gtserror.NewErrorNotFound(err) - } - - if followRequest.Account == nil { - // The creator of the request doesn't exist, - // so we can't do further processing. - return p.relationship(ctx, auth.Account.ID, accountID) - } - - p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{ - APObjectType: ap.ActivityFollow, - APActivityType: ap.ActivityReject, - GTSModel: followRequest, - OriginAccount: followRequest.Account, - TargetAccount: followRequest.TargetAccount, - }) - - return p.relationship(ctx, auth.Account.ID, accountID) -} - -func (p *Processor) relationship(ctx context.Context, accountID string, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) { - relationship, err := p.state.DB.GetRelationship(ctx, accountID, targetAccountID) - if err != nil { - return nil, gtserror.NewErrorInternalError(err) - } - - apiRelationship, err := p.tc.RelationshipToAPIRelationship(ctx, relationship) - if err != nil { - return nil, gtserror.NewErrorInternalError(err) - } - - return apiRelationship, nil -} diff --git a/internal/processing/followrequest_test.go b/internal/processing/followrequest_test.go index addb5052e..4c089be4a 100644 --- a/internal/processing/followrequest_test.go +++ b/internal/processing/followrequest_test.go @@ -30,35 +30,57 @@ import ( "github.com/superseriousbusiness/gotosocial/testrig" ) +// TODO: move this to the "internal/processing/account" pkg type FollowRequestTestSuite struct { ProcessingStandardTestSuite } func (suite *FollowRequestTestSuite) TestFollowRequestAccept() { - requestingAccount := suite.testAccounts["remote_account_2"] - targetAccount := suite.testAccounts["local_account_1"] + // The authed local account we are going to use for HTTP requests + requestingAccount := suite.testAccounts["local_account_1"] + + // The remote account whose follow request we are accepting + targetAccount := suite.testAccounts["remote_account_2"] // put a follow request in the database fr := >smodel.FollowRequest{ ID: "01FJ1S8DX3STJJ6CEYPMZ1M0R3", CreatedAt: time.Now(), UpdatedAt: time.Now(), - URI: fmt.Sprintf("%s/follow/01FJ1S8DX3STJJ6CEYPMZ1M0R3", requestingAccount.URI), - AccountID: requestingAccount.ID, - TargetAccountID: targetAccount.ID, + URI: fmt.Sprintf("%s/follow/01FJ1S8DX3STJJ6CEYPMZ1M0R3", targetAccount.URI), + AccountID: targetAccount.ID, + TargetAccountID: requestingAccount.ID, } err := suite.db.Put(context.Background(), fr) suite.NoError(err) - relationship, errWithCode := suite.processor.FollowRequestAccept(context.Background(), suite.testAutheds["local_account_1"], requestingAccount.ID) + relationship, errWithCode := suite.processor.Account().FollowRequestAccept( + context.Background(), + requestingAccount, + targetAccount.ID, + ) suite.NoError(errWithCode) - suite.EqualValues(&apimodel.Relationship{ID: "01FHMQX3GAABWSM0S2VZEC2SWC", Following: false, ShowingReblogs: false, Notifying: false, FollowedBy: true, Blocking: false, BlockedBy: false, Muting: false, MutingNotifications: false, Requested: false, DomainBlocking: false, Endorsed: false, Note: ""}, relationship) + suite.EqualValues(&apimodel.Relationship{ + ID: "01FHMQX3GAABWSM0S2VZEC2SWC", + Following: false, + ShowingReblogs: false, + Notifying: false, + FollowedBy: true, + Blocking: false, + BlockedBy: false, + Muting: false, + MutingNotifications: false, + Requested: false, + DomainBlocking: false, + Endorsed: false, + Note: "", + }, relationship) // accept should be sent to Some_User var sent [][]byte if !testrig.WaitFor(func() bool { - sentI, ok := suite.httpClient.SentMessages.Load(requestingAccount.InboxURI) + sentI, ok := suite.httpClient.SentMessages.Load(targetAccount.InboxURI) if ok { sent, ok = sentI.([][]byte) if !ok { @@ -87,41 +109,45 @@ func (suite *FollowRequestTestSuite) TestFollowRequestAccept() { err = json.Unmarshal(sent[0], accept) suite.NoError(err) - suite.Equal(targetAccount.URI, accept.Actor) - suite.Equal(requestingAccount.URI, accept.Object.Actor) + suite.Equal(requestingAccount.URI, accept.Actor) + suite.Equal(targetAccount.URI, accept.Object.Actor) suite.Equal(fr.URI, accept.Object.ID) - suite.Equal(targetAccount.URI, accept.Object.Object) - suite.Equal(targetAccount.URI, accept.Object.To) + suite.Equal(requestingAccount.URI, accept.Object.Object) + suite.Equal(requestingAccount.URI, accept.Object.To) suite.Equal("Follow", accept.Object.Type) - suite.Equal(requestingAccount.URI, accept.To) + suite.Equal(targetAccount.URI, accept.To) suite.Equal("Accept", accept.Type) } func (suite *FollowRequestTestSuite) TestFollowRequestReject() { - requestingAccount := suite.testAccounts["remote_account_2"] - targetAccount := suite.testAccounts["local_account_1"] + requestingAccount := suite.testAccounts["local_account_1"] + targetAccount := suite.testAccounts["remote_account_2"] // put a follow request in the database fr := >smodel.FollowRequest{ ID: "01FJ1S8DX3STJJ6CEYPMZ1M0R3", CreatedAt: time.Now(), UpdatedAt: time.Now(), - URI: fmt.Sprintf("%s/follow/01FJ1S8DX3STJJ6CEYPMZ1M0R3", requestingAccount.URI), - AccountID: requestingAccount.ID, - TargetAccountID: targetAccount.ID, + URI: fmt.Sprintf("%s/follow/01FJ1S8DX3STJJ6CEYPMZ1M0R3", targetAccount.URI), + AccountID: targetAccount.ID, + TargetAccountID: requestingAccount.ID, } err := suite.db.Put(context.Background(), fr) suite.NoError(err) - relationship, errWithCode := suite.processor.FollowRequestReject(context.Background(), suite.testAutheds["local_account_1"], requestingAccount.ID) + relationship, errWithCode := suite.processor.Account().FollowRequestReject( + context.Background(), + requestingAccount, + targetAccount.ID, + ) suite.NoError(errWithCode) suite.EqualValues(&apimodel.Relationship{ID: "01FHMQX3GAABWSM0S2VZEC2SWC", Following: false, ShowingReblogs: false, Notifying: false, FollowedBy: false, Blocking: false, BlockedBy: false, Muting: false, MutingNotifications: false, Requested: false, DomainBlocking: false, Endorsed: false, Note: ""}, relationship) // reject should be sent to Some_User var sent [][]byte if !testrig.WaitFor(func() bool { - sentI, ok := suite.httpClient.SentMessages.Load(requestingAccount.InboxURI) + sentI, ok := suite.httpClient.SentMessages.Load(targetAccount.InboxURI) if ok { sent, ok = sentI.([][]byte) if !ok { @@ -150,13 +176,13 @@ func (suite *FollowRequestTestSuite) TestFollowRequestReject() { err = json.Unmarshal(sent[0], reject) suite.NoError(err) - suite.Equal(targetAccount.URI, reject.Actor) - suite.Equal(requestingAccount.URI, reject.Object.Actor) + suite.Equal(requestingAccount.URI, reject.Actor) + suite.Equal(targetAccount.URI, reject.Object.Actor) suite.Equal(fr.URI, reject.Object.ID) - suite.Equal(targetAccount.URI, reject.Object.Object) - suite.Equal(targetAccount.URI, reject.Object.To) + suite.Equal(requestingAccount.URI, reject.Object.Object) + suite.Equal(requestingAccount.URI, reject.Object.To) suite.Equal("Follow", reject.Object.Type) - suite.Equal(requestingAccount.URI, reject.To) + suite.Equal(targetAccount.URI, reject.To) suite.Equal("Reject", reject.Type) } diff --git a/internal/processing/processor.go b/internal/processing/processor.go index c0fd15a24..f814d5a96 100644 --- a/internal/processing/processor.go +++ b/internal/processing/processor.go @@ -24,6 +24,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/processing/account" "github.com/superseriousbusiness/gotosocial/internal/processing/admin" + "github.com/superseriousbusiness/gotosocial/internal/processing/common" "github.com/superseriousbusiness/gotosocial/internal/processing/fedi" "github.com/superseriousbusiness/gotosocial/internal/processing/list" "github.com/superseriousbusiness/gotosocial/internal/processing/markers" @@ -147,7 +148,8 @@ func NewProcessor( // // Start with sub processors that will // be required by the workers processor. - accountProcessor := account.New(state, tc, mediaManager, oauthServer, federator, filter, parseMentionFunc) + commonProcessor := common.New(state, tc, federator, filter) + accountProcessor := account.New(&commonProcessor, state, tc, mediaManager, oauthServer, federator, filter, parseMentionFunc) mediaProcessor := media.New(state, tc, mediaManager, federator.TransportController()) streamProcessor := stream.New(state, oauthServer) diff --git a/internal/timeline/get_test.go b/internal/timeline/get_test.go index 7b31ec977..4522d5858 100644 --- a/internal/timeline/get_test.go +++ b/internal/timeline/get_test.go @@ -66,6 +66,7 @@ func (suite *GetTestSuite) emptyAccountFollows(ctx context.Context, accountID st follows, err := suite.state.DB.GetAccountFollows( gtscontext.SetBarebones(ctx), accountID, + nil, // select all ) if err != nil { suite.FailNow(err.Error()) @@ -82,6 +83,7 @@ func (suite *GetTestSuite) emptyAccountFollows(ctx context.Context, accountID st follows, err = suite.state.DB.GetAccountFollows( gtscontext.SetBarebones(ctx), accountID, + nil, // select all ) if err != nil { suite.FailNow(err.Error()) diff --git a/testrig/testmodels.go b/testrig/testmodels.go index 4f0768b45..fa6ff92ff 100644 --- a/testrig/testmodels.go +++ b/testrig/testmodels.go @@ -364,6 +364,7 @@ func NewTestAccounts() map[string]*gtsmodel.Account { SuspendedAt: time.Time{}, HideCollections: util.Ptr(false), SuspensionOrigin: "", + EnableRSS: util.Ptr(false), }, "admin_account": { ID: "01F8MH17FWEB39HZJ76B6VXSKF", @@ -539,6 +540,7 @@ func NewTestAccounts() map[string]*gtsmodel.Account { SuspendedAt: time.Time{}, HideCollections: util.Ptr(false), SuspensionOrigin: "", + EnableRSS: util.Ptr(false), }, "remote_account_2": { ID: "01FHMQX3GAABWSM0S2VZEC2SWC", @@ -575,6 +577,7 @@ func NewTestAccounts() map[string]*gtsmodel.Account { SuspendedAt: time.Time{}, HideCollections: util.Ptr(false), SuspensionOrigin: "", + EnableRSS: util.Ptr(false), }, "remote_account_3": { ID: "062G5WYKY35KKD12EMSM3F8PJ8", @@ -612,6 +615,7 @@ func NewTestAccounts() map[string]*gtsmodel.Account { HideCollections: util.Ptr(false), SuspensionOrigin: "", HeaderMediaAttachmentID: "01PFPMWK2FF0D9WMHEJHR07C3R", + EnableRSS: util.Ptr(false), }, "remote_account_4": { ID: "07GZRBAEMBNKGZ8Z9VSKSXKR98", diff --git a/vendor/github.com/tomnomnom/linkheader/.gitignore b/vendor/github.com/tomnomnom/linkheader/.gitignore new file mode 100644 index 000000000..0a00ddebb --- /dev/null +++ b/vendor/github.com/tomnomnom/linkheader/.gitignore @@ -0,0 +1,2 @@ +cpu.out +linkheader.test diff --git a/vendor/github.com/tomnomnom/linkheader/.travis.yml b/vendor/github.com/tomnomnom/linkheader/.travis.yml new file mode 100644 index 000000000..cfda08659 --- /dev/null +++ b/vendor/github.com/tomnomnom/linkheader/.travis.yml @@ -0,0 +1,6 @@ +language: go + +go: + - 1.6 + - 1.7 + - tip diff --git a/vendor/github.com/tomnomnom/linkheader/CONTRIBUTING.mkd b/vendor/github.com/tomnomnom/linkheader/CONTRIBUTING.mkd new file mode 100644 index 000000000..0339bec55 --- /dev/null +++ b/vendor/github.com/tomnomnom/linkheader/CONTRIBUTING.mkd @@ -0,0 +1,10 @@ +# Contributing + +* Raise an issue if appropriate +* Fork the repo +* Bootstrap the dev dependencies (run `./script/bootstrap`) +* Make your changes +* Use [gofmt](https://golang.org/cmd/gofmt/) +* Make sure the tests pass (run `./script/test`) +* Make sure the linters pass (run `./script/lint`) +* Issue a pull request diff --git a/vendor/github.com/tomnomnom/linkheader/LICENSE b/vendor/github.com/tomnomnom/linkheader/LICENSE new file mode 100644 index 000000000..55192df56 --- /dev/null +++ b/vendor/github.com/tomnomnom/linkheader/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2016 Tom Hudson + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/vendor/github.com/tomnomnom/linkheader/README.mkd b/vendor/github.com/tomnomnom/linkheader/README.mkd new file mode 100644 index 000000000..2a949cac2 --- /dev/null +++ b/vendor/github.com/tomnomnom/linkheader/README.mkd @@ -0,0 +1,35 @@ +# Golang Link Header Parser + +Library for parsing HTTP Link headers. Requires Go 1.6 or higher. + +Docs can be found on [the GoDoc page](https://godoc.org/github.com/tomnomnom/linkheader). + +[![Build Status](https://travis-ci.org/tomnomnom/linkheader.svg)](https://travis-ci.org/tomnomnom/linkheader) + +## Basic Example + +```go +package main + +import ( + "fmt" + + "github.com/tomnomnom/linkheader" +) + +func main() { + header := "; rel=\"next\"," + + "; rel=\"last\"" + links := linkheader.Parse(header) + + for _, link := range links { + fmt.Printf("URL: %s; Rel: %s\n", link.URL, link.Rel) + } +} + +// Output: +// URL: https://api.github.com/user/58276/repos?page=2; Rel: next +// URL: https://api.github.com/user/58276/repos?page=2; Rel: last +``` + + diff --git a/vendor/github.com/tomnomnom/linkheader/main.go b/vendor/github.com/tomnomnom/linkheader/main.go new file mode 100644 index 000000000..6b81321b8 --- /dev/null +++ b/vendor/github.com/tomnomnom/linkheader/main.go @@ -0,0 +1,151 @@ +// Package linkheader provides functions for parsing HTTP Link headers +package linkheader + +import ( + "fmt" + "strings" +) + +// A Link is a single URL and related parameters +type Link struct { + URL string + Rel string + Params map[string]string +} + +// HasParam returns if a Link has a particular parameter or not +func (l Link) HasParam(key string) bool { + for p := range l.Params { + if p == key { + return true + } + } + return false +} + +// Param returns the value of a parameter if it exists +func (l Link) Param(key string) string { + for k, v := range l.Params { + if key == k { + return v + } + } + return "" +} + +// String returns the string representation of a link +func (l Link) String() string { + + p := make([]string, 0, len(l.Params)) + for k, v := range l.Params { + p = append(p, fmt.Sprintf("%s=\"%s\"", k, v)) + } + if l.Rel != "" { + p = append(p, fmt.Sprintf("%s=\"%s\"", "rel", l.Rel)) + } + return fmt.Sprintf("<%s>; %s", l.URL, strings.Join(p, "; ")) +} + +// Links is a slice of Link structs +type Links []Link + +// FilterByRel filters a group of Links by the provided Rel attribute +func (l Links) FilterByRel(r string) Links { + links := make(Links, 0) + for _, link := range l { + if link.Rel == r { + links = append(links, link) + } + } + return links +} + +// String returns the string representation of multiple Links +// for use in HTTP responses etc +func (l Links) String() string { + if l == nil { + return fmt.Sprint(nil) + } + + var strs []string + for _, link := range l { + strs = append(strs, link.String()) + } + return strings.Join(strs, ", ") +} + +// Parse parses a raw Link header in the form: +// ; rel="foo", ; rel="bar"; wat="dis" +// returning a slice of Link structs +func Parse(raw string) Links { + var links Links + + // One chunk: ; rel="foo" + for _, chunk := range strings.Split(raw, ",") { + + link := Link{URL: "", Rel: "", Params: make(map[string]string)} + + // Figure out what each piece of the chunk is + for _, piece := range strings.Split(chunk, ";") { + + piece = strings.Trim(piece, " ") + if piece == "" { + continue + } + + // URL + if piece[0] == '<' && piece[len(piece)-1] == '>' { + link.URL = strings.Trim(piece, "<>") + continue + } + + // Params + key, val := parseParam(piece) + if key == "" { + continue + } + + // Special case for rel + if strings.ToLower(key) == "rel" { + link.Rel = val + } else { + link.Params[key] = val + } + } + + if link.URL != "" { + links = append(links, link) + } + } + + return links +} + +// ParseMultiple is like Parse, but accepts a slice of headers +// rather than just one header string +func ParseMultiple(headers []string) Links { + links := make(Links, 0) + for _, header := range headers { + links = append(links, Parse(header)...) + } + return links +} + +// parseParam takes a raw param in the form key="val" and +// returns the key and value as seperate strings +func parseParam(raw string) (key, val string) { + + parts := strings.SplitN(raw, "=", 2) + if len(parts) == 1 { + return parts[0], "" + } + if len(parts) != 2 { + return "", "" + } + + key = parts[0] + val = strings.Trim(parts[1], "\"") + + return key, val + +} diff --git a/vendor/modules.txt b/vendor/modules.txt index dae43a87f..99c7b384f 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -672,6 +672,9 @@ github.com/tdewolff/parse/v2/strconv # github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc ## explicit github.com/tmthrgd/go-hex +# github.com/tomnomnom/linkheader v0.0.0-20180905144013-02ca5825eb80 +## explicit +github.com/tomnomnom/linkheader # github.com/twitchyliquid64/golang-asm v0.15.1 ## explicit; go 1.13 github.com/twitchyliquid64/golang-asm/asm/arch