[feature] Add List functionality (#1802)

* start working on lists

* further list work

* test list db functions nicely

* more work on lists

* peepoopeepoo

* poke

* start list timeline func

* we're getting there lads

* couldn't be me working on stuff... could it?

* hook up handlers

* fiddling

* weeee

* woah

* screaming, pissing

* fix streaming being a whiny baby

* lint, small test fix, swagger

* tidying up, testing

* fucked! by the linter

* move timelines to state like a boss

* add timeline start to tests using state

* invalidate lists
This commit is contained in:
tobi 2023-05-25 10:37:38 +02:00 committed by GitHub
parent 282be6f26d
commit f5c004d67d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
123 changed files with 5654 additions and 970 deletions

View file

@ -32,7 +32,10 @@ import (
apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/middleware"
tlprocessor "github.com/superseriousbusiness/gotosocial/internal/processing/timeline"
"github.com/superseriousbusiness/gotosocial/internal/timeline"
"github.com/superseriousbusiness/gotosocial/internal/tracing"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
"go.uber.org/automaxprocs/maxprocs"
"github.com/superseriousbusiness/gotosocial/internal/config"
@ -72,7 +75,6 @@ var Start action.GTSAction = func(ctx context.Context) error {
defer state.Caches.Stop()
// Initialize Tracing
if err := tracing.Initialize(); err != nil {
return fmt.Errorf("error initializing tracing: %w", err)
}
@ -110,36 +112,56 @@ var Start action.GTSAction = func(ctx context.Context) error {
state.Workers.Start()
defer state.Workers.Stop()
// build backend handlers
// Build handlers used in later initializations.
mediaManager := media.NewManager(&state)
oauthServer := oauth.New(ctx, dbService)
typeConverter := typeutils.NewConverter(dbService)
filter := visibility.NewFilter(&state)
federatingDB := federatingdb.New(&state, typeConverter)
transportController := transport.NewController(&state, federatingDB, &federation.Clock{}, client)
federator := federation.NewFederator(&state, federatingDB, transportController, typeConverter, mediaManager)
// decide whether to create a noop email sender (won't send emails) or a real one
// Decide whether to create a noop email
// sender (won't send emails) or a real one.
var emailSender email.Sender
if smtpHost := config.GetSMTPHost(); smtpHost != "" {
// host is defined so create a proper sender
// Host is defined; create a proper sender.
emailSender, err = email.NewSender()
if err != nil {
return fmt.Errorf("error creating email sender: %s", err)
}
} else {
// no host is defined so create a noop sender
// No host is defined; create a noop sender.
emailSender, err = email.NewNoopSender(nil)
if err != nil {
return fmt.Errorf("error creating noop email sender: %s", err)
}
}
// create the message processor using the other services we've created so far
processor := processing.NewProcessor(typeConverter, federator, oauthServer, mediaManager, &state, emailSender)
if err := processor.Start(); err != nil {
return fmt.Errorf("error creating processor: %s", err)
// Initialize timelines.
state.Timelines.Home = timeline.NewManager(
tlprocessor.HomeTimelineGrab(&state),
tlprocessor.HomeTimelineFilter(&state, filter),
tlprocessor.HomeTimelineStatusPrepare(&state, typeConverter),
tlprocessor.SkipInsert(),
)
if err := state.Timelines.Home.Start(); err != nil {
return fmt.Errorf("error starting home timeline: %s", err)
}
state.Timelines.List = timeline.NewManager(
tlprocessor.ListTimelineGrab(&state),
tlprocessor.ListTimelineFilter(&state, filter),
tlprocessor.ListTimelineStatusPrepare(&state, typeConverter),
tlprocessor.SkipInsert(),
)
if err := state.Timelines.List.Start(); err != nil {
return fmt.Errorf("error starting list timeline: %s", err)
}
// Create the processor using all the other services we've created so far.
processor := processing.NewProcessor(typeConverter, federator, oauthServer, mediaManager, &state, emailSender)
// Set state client / federator worker enqueue functions
state.Workers.EnqueueClientAPI = processor.EnqueueClientAPI
state.Workers.EnqueueFederator = processor.EnqueueFederator

View file

@ -38,9 +38,12 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/middleware"
"github.com/superseriousbusiness/gotosocial/internal/oidc"
tlprocessor "github.com/superseriousbusiness/gotosocial/internal/processing/timeline"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/timeline"
"github.com/superseriousbusiness/gotosocial/internal/tracing"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
"github.com/superseriousbusiness/gotosocial/internal/web"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -89,11 +92,31 @@ var Start action.GTSAction = func(ctx context.Context) error {
federator := testrig.NewTestFederator(&state, transportController, mediaManager)
emailSender := testrig.NewEmailSender("./web/template/", nil)
typeConverter := testrig.NewTestTypeConverter(state.DB)
filter := visibility.NewFilter(&state)
// Initialize timelines.
state.Timelines.Home = timeline.NewManager(
tlprocessor.HomeTimelineGrab(&state),
tlprocessor.HomeTimelineFilter(&state, filter),
tlprocessor.HomeTimelineStatusPrepare(&state, typeConverter),
tlprocessor.SkipInsert(),
)
if err := state.Timelines.Home.Start(); err != nil {
return fmt.Errorf("error starting home timeline: %s", err)
}
state.Timelines.List = timeline.NewManager(
tlprocessor.ListTimelineGrab(&state),
tlprocessor.ListTimelineFilter(&state, filter),
tlprocessor.ListTimelineStatusPrepare(&state, typeConverter),
tlprocessor.SkipInsert(),
)
if err := state.Timelines.List.Start(); err != nil {
return fmt.Errorf("error starting list timeline: %s", err)
}
processor := testrig.NewTestProcessor(&state, federator, emailSender, mediaManager)
if err := processor.Start(); err != nil {
return fmt.Errorf("error starting processor: %s", err)
}
/*
HTTP router initialization

View file

@ -1635,6 +1635,28 @@ definitions:
type: object
x-go-name: InstanceV2Users
x-go-package: github.com/superseriousbusiness/gotosocial/internal/api/model
list:
properties:
id:
description: The ID of the list.
type: string
x-go-name: ID
replies_policy:
description: |-
RepliesPolicy for this list.
followed = Show replies to any followed user
list = Show replies to members of the list
none = Show replies to no one
type: string
x-go-name: RepliesPolicy
title:
description: The user-defined title of the list.
type: string
x-go-name: Title
title: List represents a user-created list of accounts that the user follows.
type: object
x-go-name: List
x-go-package: github.com/superseriousbusiness/gotosocial/internal/api/model
mediaDimensions:
properties:
aspect:
@ -2881,6 +2903,40 @@ paths:
summary: See accounts followed by given account id.
tags:
- accounts
/api/v1/accounts/{id}/lists:
get:
operationId: accountLists
parameters:
- description: Account ID.
in: path
name: id
required: true
type: string
produces:
- application/json
responses:
"200":
description: Array of all lists containing this account.
schema:
items:
$ref: '#/definitions/list'
type: array
"400":
description: bad request
"401":
description: unauthorized
"404":
description: not found
"406":
description: not acceptable
"500":
description: internal server error
security:
- OAuth2 Bearer:
- read:lists
summary: See all lists of yours that contain requested account.
tags:
- accounts
/api/v1/accounts/{id}/statuses:
get:
description: The statuses will be returned in descending chronological order (newest first), with sequential IDs (bigger = newer).
@ -3211,7 +3267,7 @@ paths:
name: id
required: true
type: string
- description: Type of action to be taken (`disable`, `silence`, or `suspend`).
- description: Type of action to be taken, currently only supports `suspend`.
in: formData
name: type
required: true
@ -4453,6 +4509,343 @@ paths:
description: internal server error
tags:
- instance
/api/v1/list:
post:
consumes:
- application/json
- application/xml
- application/x-www-form-urlencoded
operationId: listCreate
parameters:
- description: Title of this list.
example: Cool People
in: formData
name: title
required: true
type: string
x-go-name: Title
- default: list
description: |-
RepliesPolicy for this list.
followed = Show replies to any followed user
list = Show replies to members of the list
none = Show replies to no one
example: list
in: formData
name: replies_policy
type: string
x-go-name: RepliesPolicy
produces:
- application/json
responses:
"200":
description: The newly created list.
schema:
$ref: '#/definitions/list'
"400":
description: bad request
"401":
description: unauthorized
"403":
description: forbidden
"404":
description: not found
"406":
description: not acceptable
"500":
description: internal server error
security:
- OAuth2 Bearer:
- write:lists
summary: Create a new list.
tags:
- lists
put:
consumes:
- application/json
- application/xml
- application/x-www-form-urlencoded
operationId: listUpdate
parameters:
- description: ID of the list
example: Cool People
in: path
name: id
required: true
type: string
x-go-name: Title
- description: Title of this list.
example: Cool People
in: formData
name: title
type: string
x-go-name: RepliesPolicy
- description: |-
RepliesPolicy for this list.
followed = Show replies to any followed user
list = Show replies to members of the list
none = Show replies to no one
example: list
in: formData
name: replies_policy
type: string
produces:
- application/json
responses:
"200":
description: The newly updated list.
schema:
$ref: '#/definitions/list'
"400":
description: bad request
"401":
description: unauthorized
"403":
description: forbidden
"404":
description: not found
"406":
description: not acceptable
"500":
description: internal server error
security:
- OAuth2 Bearer:
- write:lists
summary: Update an existing list.
tags:
- lists
/api/v1/list/{id}:
delete:
operationId: listDelete
parameters:
- description: ID of the list
in: path
name: id
required: true
type: string
produces:
- application/json
responses:
"200":
description: list deleted
"400":
description: bad request
"401":
description: unauthorized
"404":
description: not found
"406":
description: not acceptable
"500":
description: internal server error
security:
- OAuth2 Bearer:
- write:lists
summary: Delete a single list with the given ID.
tags:
- lists
get:
operationId: list
parameters:
- description: ID of the list
in: path
name: id
required: true
type: string
produces:
- application/json
responses:
"200":
description: Requested list.
schema:
$ref: '#/definitions/list'
"400":
description: bad request
"401":
description: unauthorized
"404":
description: not found
"406":
description: not acceptable
"500":
description: internal server error
security:
- OAuth2 Bearer:
- read:lists
summary: Get a single list with the given ID.
tags:
- lists
/api/v1/list/{id}/accounts:
delete:
consumes:
- application/json
- application/xml
- application/x-www-form-urlencoded
operationId: removeListAccounts
parameters:
- description: ID of the list
in: path
name: id
required: true
type: string
- description: Array of accountIDs to modify. Each accountID must correspond to an account that the requesting account follows.
in: formData
items:
type: string
name: account_ids
required: true
type: array
produces:
- application/json
responses:
"200":
description: list accounts updated
"400":
description: bad request
"401":
description: unauthorized
"404":
description: not found
"406":
description: not acceptable
"500":
description: internal server error
security:
- OAuth2 Bearer:
- read:lists
summary: Remove one or more accounts from the given list.
tags:
- lists
get:
description: |-
The returned Link header can be used to generate the previous and next queries when scrolling up or down a timeline.
Example:
```
<https://example.org/api/v1/list/01H0W619198FX7J54NF7EH1NG2/accounts?limit=20&max_id=01FC3GSQ8A3MMJ43BPZSGEG29M>; rel="next", <https://example.org/api/v1/list/01H0W619198FX7J54NF7EH1NG2/accounts?limit=20&min_id=01FC3KJW2GYXSDDRA6RWNDM46M>; rel="prev"
````
operationId: listAccounts
parameters:
- description: ID of the list
in: path
name: id
required: true
type: string
- description: Return only list entries *OLDER* than the given max ID. The account from the list entry with the specified ID will not be included in the response.
in: query
name: max_id
type: string
- description: Return only list entries *NEWER* than the given since ID. The account from the list entry with the specified ID will not be included in the response.
in: query
name: since_id
type: string
- description: Return only list entries *IMMEDIATELY NEWER* than the given min ID. The account from the list entry with the specified ID will not be included in the response.
in: query
name: min_id
type: string
- default: 20
description: Number of accounts to return.
in: query
name: limit
type: integer
produces:
- application/json
responses:
"200":
description: Array of accounts.
headers:
Link:
description: Links to the next and previous queries.
type: string
schema:
items:
$ref: '#/definitions/account'
type: array
"400":
description: bad request
"401":
description: unauthorized
"404":
description: not found
"406":
description: not acceptable
"500":
description: internal server error
security:
- OAuth2 Bearer:
- read:lists
summary: Page through accounts in this list.
tags:
- lists
post:
consumes:
- application/json
- application/xml
- application/x-www-form-urlencoded
operationId: addListAccounts
parameters:
- description: ID of the list
in: path
name: id
required: true
type: string
- description: Array of accountIDs to modify. Each accountID must correspond to an account that the requesting account follows.
in: formData
items:
type: string
name: account_ids
required: true
type: array
produces:
- application/json
responses:
"200":
description: list accounts updated
"400":
description: bad request
"401":
description: unauthorized
"404":
description: not found
"406":
description: not acceptable
"500":
description: internal server error
security:
- OAuth2 Bearer:
- read:lists
summary: Add one or more accounts to the given list.
tags:
- lists
/api/v1/lists:
get:
operationId: lists
produces:
- application/json
responses:
"200":
description: Array of all lists owned by the requesting user.
schema:
items:
$ref: '#/definitions/list'
type: array
"400":
description: bad request
"401":
description: unauthorized
"404":
description: not found
"406":
description: not acceptable
"500":
description: internal server error
security:
- OAuth2 Bearer:
- read:lists
summary: Get all lists for owned by authorized user.
tags:
- lists
/api/v1/media/{id}:
get:
operationId: mediaGet
@ -5579,6 +5972,18 @@ paths:
name: stream
required: true
type: string
- description: |-
ID of the list to subscribe to.
Only used if stream type is 'list'.
in: query
name: list
type: string
- description: |-
Name of the tag to subscribe to.
Only used if stream type is 'hashtag' or 'hashtag:local'.
in: query
name: tag
type: string
produces:
- application/json
responses:
@ -5696,6 +6101,65 @@ paths:
summary: See statuses/posts by accounts you follow.
tags:
- timelines
/api/v1/timelines/list/{id}:
get:
description: |-
The statuses will be returned in descending chronological order (newest first), with sequential IDs (bigger = newer).
The returned Link header can be used to generate the previous and next queries when scrolling up or down a timeline.
Example:
```
<https://example.org/api/v1/timelines/list/01H0W619198FX7J54NF7EH1NG2?limit=20&max_id=01FC3GSQ8A3MMJ43BPZSGEG29M>; rel="next", <https://example.org/api/v1/timelines/list/01H0W619198FX7J54NF7EH1NG2?limit=20&min_id=01FC3KJW2GYXSDDRA6RWNDM46M>; rel="prev"
````
operationId: listTimeline
parameters:
- description: ID of the list
in: path
name: id
required: true
type: string
- description: Return only statuses *OLDER* than the given max status ID. The status with the specified ID will not be included in the response.
in: query
name: max_id
type: string
- description: Return only statuses *NEWER* than the given since status ID. The status with the specified ID will not be included in the response.
in: query
name: since_id
type: string
- description: Return only statuses *NEWER* than the given since status ID. The status with the specified ID will not be included in the response.
in: query
name: min_id
type: string
- default: 20
description: Number of statuses to return.
in: query
name: limit
type: integer
produces:
- application/json
responses:
"200":
description: Array of statuses.
headers:
Link:
description: Links to the next and previous queries.
type: string
schema:
items:
$ref: '#/definitions/status'
type: array
"400":
description: bad request
"401":
description: unauthorized
security:
- OAuth2 Bearer:
- read:lists
summary: See statuses/posts from the given list timeline.
tags:
- timelines
/api/v1/timelines/public:
get:
description: |-
@ -5980,6 +6444,7 @@ securityDefinitions:
read:custom_emojis: grant read access to custom_emojis
read:favourites: grant read access to favourites
read:follows: grant read access to follows
read:lists: grant read access to lists
read:media: grant read access to media
read:notifications: grants read access to notifications
read:search: grant read access to searches
@ -5990,6 +6455,7 @@ securityDefinitions:
write:accounts: grants write access to accounts
write:blocks: grants write access to blocks
write:follows: grants write access to follows
write:lists: grants write access to lists
write:media: grants write access to media
write:statuses: grants write access to statuses
write:user: grants write access to user-level info

View file

@ -37,6 +37,7 @@
// read:custom_emojis: grant read access to custom_emojis
// read:favourites: grant read access to favourites
// read:follows: grant read access to follows
// read:lists: grant read access to lists
// read:media: grant read access to media
// read:search: grant read access to searches
// read:statuses: grants read access to statuses
@ -47,6 +48,7 @@
// write:accounts: grants write access to accounts
// write:blocks: grants write access to blocks
// write:follows: grants write access to follows
// write:lists: grants write access to lists
// write:media: grants write access to media
// write:statuses: grants write access to statuses
// write:user: grants write access to user-level info

View file

@ -289,6 +289,14 @@ cache:
follow-request-ttl: "30m"
follow-request-sweep-freq: "1m"
list-max-size: 2000
list-ttl: "30m"
list-sweep-freq: "1m"
list-entry-max-size: 2000
list-entry-ttl: "30m"
list-entry-sweep-freq: "1m"
media-max-size: 1000
media-ttl: "30m"
media-sweep-freq: "1m"

View file

@ -36,6 +36,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -74,9 +75,14 @@ func (suite *EmojiGetTestSuite) SetupTest() {
suite.state.DB = suite.db
suite.storage = testrig.NewInMemoryStorage()
suite.state.Storage = suite.storage
suite.tc = testrig.NewTestTypeConverter(suite.db)
testrig.StartTimelines(
&suite.state,
visibility.NewFilter(&suite.state),
suite.tc,
)
suite.mediaManager = testrig.NewTestMediaManager(&suite.state)
suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager)
suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil)
@ -86,8 +92,6 @@ func (suite *EmojiGetTestSuite) SetupTest() {
testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media")
suite.signatureCheck = middleware.SignatureCheck(suite.db.IsURIBlocked)
suite.NoError(suite.processor.Start())
}
func (suite *EmojiGetTestSuite) TearDownTest() {

View file

@ -89,7 +89,6 @@ func (suite *InboxPostTestSuite) TestPostBlock() {
emailSender := testrig.NewEmailSender("../../../../web/template/", nil)
processor := testrig.NewTestProcessor(&suite.state, federator, emailSender, suite.mediaManager)
userModule := users.New(processor)
suite.NoError(processor.Start())
// setup request
recorder := httptest.NewRecorder()
@ -190,7 +189,6 @@ func (suite *InboxPostTestSuite) TestPostUnblock() {
emailSender := testrig.NewEmailSender("../../../../web/template/", nil)
processor := testrig.NewTestProcessor(&suite.state, federator, emailSender, suite.mediaManager)
userModule := users.New(processor)
suite.NoError(processor.Start())
// setup request
recorder := httptest.NewRecorder()
@ -296,7 +294,6 @@ func (suite *InboxPostTestSuite) TestPostUpdate() {
emailSender := testrig.NewEmailSender("../../../../web/template/", nil)
processor := testrig.NewTestProcessor(&suite.state, federator, emailSender, suite.mediaManager)
userModule := users.New(processor)
suite.NoError(processor.Start())
// setup request
recorder := httptest.NewRecorder()
@ -425,7 +422,6 @@ func (suite *InboxPostTestSuite) TestPostDelete() {
emailSender := testrig.NewEmailSender("../../../../web/template/", nil)
processor := testrig.NewTestProcessor(&suite.state, federator, emailSender, suite.mediaManager)
userModule := users.New(processor)
suite.NoError(processor.Start())
// setup request
recorder := httptest.NewRecorder()

View file

@ -106,7 +106,6 @@ func (suite *OutboxGetTestSuite) TestGetOutboxFirstPage() {
emailSender := testrig.NewEmailSender("../../../../web/template/", nil)
processor := testrig.NewTestProcessor(&suite.state, federator, emailSender, suite.mediaManager)
userModule := users.New(processor)
suite.NoError(processor.Start())
// setup request
recorder := httptest.NewRecorder()
@ -181,7 +180,6 @@ func (suite *OutboxGetTestSuite) TestGetOutboxNextPage() {
emailSender := testrig.NewEmailSender("../../../../web/template/", nil)
processor := testrig.NewTestProcessor(&suite.state, federator, emailSender, suite.mediaManager)
userModule := users.New(processor)
suite.NoError(processor.Start())
// setup request
recorder := httptest.NewRecorder()

View file

@ -106,7 +106,6 @@ func (suite *RepliesGetTestSuite) TestGetRepliesNext() {
emailSender := testrig.NewEmailSender("../../../../web/template/", nil)
processor := testrig.NewTestProcessor(&suite.state, federator, emailSender, suite.mediaManager)
userModule := users.New(processor)
suite.NoError(processor.Start())
// setup request
recorder := httptest.NewRecorder()
@ -171,7 +170,6 @@ func (suite *RepliesGetTestSuite) TestGetRepliesLast() {
emailSender := testrig.NewEmailSender("../../../../web/template/", nil)
processor := testrig.NewTestProcessor(&suite.state, federator, emailSender, suite.mediaManager)
userModule := users.New(processor)
suite.NoError(processor.Start())
// setup request
recorder := httptest.NewRecorder()

View file

@ -31,6 +31,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -83,6 +84,13 @@ func (suite *UserStandardTestSuite) SetupTest() {
suite.db = testrig.NewTestDB(&suite.state)
suite.state.DB = suite.db
suite.tc = testrig.NewTestTypeConverter(suite.db)
testrig.StartTimelines(
&suite.state,
visibility.NewFilter(&suite.state),
suite.tc,
)
suite.storage = testrig.NewInMemoryStorage()
suite.state.Storage = suite.storage
suite.mediaManager = testrig.NewTestMediaManager(&suite.state)
@ -94,8 +102,6 @@ func (suite *UserStandardTestSuite) SetupTest() {
testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media")
suite.signatureCheck = middleware.SignatureCheck(suite.db.IsURIBlocked)
suite.NoError(suite.processor.Start())
}
func (suite *UserStandardTestSuite) TearDownTest() {

View file

@ -36,6 +36,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -86,6 +87,12 @@ func (suite *AccountStandardTestSuite) SetupTest() {
suite.storage = testrig.NewInMemoryStorage()
suite.state.Storage = suite.storage
testrig.StartTimelines(
&suite.state,
visibility.NewFilter(&suite.state),
testrig.NewTestTypeConverter(suite.db),
)
suite.mediaManager = testrig.NewTestMediaManager(&suite.state)
suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager)
suite.sentEmails = make(map[string]string)
@ -94,8 +101,6 @@ func (suite *AccountStandardTestSuite) SetupTest() {
suite.accountsModule = accounts.New(suite.processor)
testrig.StandardDBSetup(suite.db, nil)
testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media")
suite.NoError(suite.processor.Start())
}
func (suite *AccountStandardTestSuite) TearDownTest() {

View file

@ -70,6 +70,8 @@ const (
UnblockPath = BasePathWithID + "/unblock"
// DeleteAccountPath is for deleting one's account via the API
DeleteAccountPath = BasePath + "/delete"
// ListsPath is for seeing which lists an account is.
ListsPath = BasePathWithID + "/lists"
)
type Module struct {
@ -115,4 +117,7 @@ func (m *Module) Route(attachHandler func(method string, path string, f ...gin.H
// block or unblock account
attachHandler(http.MethodPost, BlockPath, m.AccountBlockPOSTHandler)
attachHandler(http.MethodPost, UnblockPath, m.AccountUnblockPOSTHandler)
// account lists
attachHandler(http.MethodGet, ListsPath, m.AccountListsGETHandler)
}

View file

@ -0,0 +1,97 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package accounts
import (
"errors"
"net/http"
"github.com/gin-gonic/gin"
apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
)
// AccountListsGETHandler swagger:operation GET /api/v1/accounts/{id}/lists accountLists
//
// See all lists of yours that contain requested account.
//
// ---
// tags:
// - accounts
//
// produces:
// - application/json
//
// parameters:
// -
// name: id
// type: string
// description: Account ID.
// in: path
// required: true
//
// security:
// - OAuth2 Bearer:
// - read:lists
//
// responses:
// '200':
// name: lists
// description: Array of all lists containing this account.
// schema:
// type: array
// items:
// "$ref": "#/definitions/list"
// '400':
// description: bad request
// '401':
// description: unauthorized
// '404':
// description: not found
// '406':
// description: not acceptable
// '500':
// description: internal server error
func (m *Module) AccountListsGETHandler(c *gin.Context) {
authed, err := oauth.Authed(c, false, false, false, false)
if err != nil {
apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGetV1)
return
}
if _, err := apiutil.NegotiateAccept(c, apiutil.JSONAcceptHeaders...); err != nil {
apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGetV1)
return
}
targetAcctID := c.Param(IDKey)
if targetAcctID == "" {
err := errors.New("no account id specified")
apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1)
return
}
lists, errWithCode := m.processor.Account().ListsGet(c.Request.Context(), authed.Account, targetAcctID)
if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return
}
c.JSON(http.StatusOK, lists)
}

View file

@ -0,0 +1,103 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package accounts_test
import (
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/suite"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/testrig"
)
type ListsTestSuite struct {
AccountStandardTestSuite
}
func (suite *ListsTestSuite) getLists(targetAccountID string, expectedHTTPStatus int, expectedBody string) []*apimodel.List {
var (
recorder = httptest.NewRecorder()
ctx, _ = testrig.CreateGinTestContext(recorder, nil)
request = httptest.NewRequest(http.MethodGet, "http://localhost:8080/api/v1/accounts/"+targetAccountID+"/lists", nil)
)
// Set up the test context.
ctx.Request = request
ctx.AddParam("id", targetAccountID)
ctx.Set(oauth.SessionAuthorizedAccount, suite.testAccounts["local_account_1"])
ctx.Set(oauth.SessionAuthorizedToken, oauth.DBTokenToToken(suite.testTokens["local_account_1"]))
ctx.Set(oauth.SessionAuthorizedApplication, suite.testApplications["application_1"])
ctx.Set(oauth.SessionAuthorizedUser, suite.testUsers["local_account_1"])
// Trigger the handler.
suite.accountsModule.AccountListsGETHandler(ctx)
// Read the result.
result := recorder.Result()
defer result.Body.Close()
b, err := io.ReadAll(result.Body)
if err != nil {
suite.FailNow(err.Error())
}
errs := gtserror.MultiError{}
// Check expected code + body.
if resultCode := recorder.Code; expectedHTTPStatus != resultCode {
errs = append(errs, fmt.Sprintf("expected %d got %d", expectedHTTPStatus, resultCode))
}
// If we got an expected body, return early.
if expectedBody != "" && string(b) != expectedBody {
errs = append(errs, fmt.Sprintf("expected %s got %s", expectedBody, string(b)))
}
if err := errs.Combine(); err != nil {
suite.FailNow("", "%v (body %s)", err, string(b))
}
// Return list response.
resp := new([]*apimodel.List)
if err := json.Unmarshal(b, resp); err != nil {
suite.FailNow(err.Error())
}
return *resp
}
func (suite *ListsTestSuite) TestGetListsHit() {
targetAccount := suite.testAccounts["admin_account"]
suite.getLists(targetAccount.ID, http.StatusOK, `[{"id":"01H0G8E4Q2J3FE3JDWJVWEDCD1","title":"Cool Ass Posters From This Instance","replies_policy":"followed"}]`)
}
func (suite *ListsTestSuite) TestGetListsNoHit() {
targetAccount := suite.testAccounts["remote_account_1"]
suite.getLists(targetAccount.ID, http.StatusOK, `[]`)
}
func TestListsTestSuite(t *testing.T) {
suite.Run(t, new(ListsTestSuite))
}

View file

@ -36,6 +36,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -92,6 +93,12 @@ func (suite *AdminStandardTestSuite) SetupTest() {
suite.storage = testrig.NewInMemoryStorage()
suite.state.Storage = suite.storage
testrig.StartTimelines(
&suite.state,
visibility.NewFilter(&suite.state),
testrig.NewTestTypeConverter(suite.db),
)
suite.mediaManager = testrig.NewTestMediaManager(&suite.state)
suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager)
suite.sentEmails = make(map[string]string)

View file

@ -42,6 +42,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -98,6 +99,13 @@ func (suite *BookmarkTestSuite) SetupTest() {
suite.state.Storage = suite.storage
suite.tc = testrig.NewTestTypeConverter(suite.db)
testrig.StartTimelines(
&suite.state,
visibility.NewFilter(&suite.state),
suite.tc,
)
testrig.StandardDBSetup(suite.db, nil)
testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media")
@ -107,8 +115,6 @@ func (suite *BookmarkTestSuite) SetupTest() {
suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager)
suite.statusModule = statuses.New(suite.processor)
suite.bookmarkModule = bookmarks.New(suite.processor)
suite.NoError(suite.processor.Start())
}
func (suite *BookmarkTestSuite) TearDownTest() {

View file

@ -29,6 +29,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -82,6 +83,13 @@ func (suite *FavouritesStandardTestSuite) SetupTest() {
suite.state.Storage = suite.storage
suite.tc = testrig.NewTestTypeConverter(suite.db)
testrig.StartTimelines(
&suite.state,
visibility.NewFilter(&suite.state),
suite.tc,
)
testrig.StandardDBSetup(suite.db, nil)
testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media")
@ -90,8 +98,6 @@ func (suite *FavouritesStandardTestSuite) SetupTest() {
suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil)
suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager)
suite.favModule = favourites.New(suite.processor)
suite.NoError(suite.processor.Start())
}
func (suite *FavouritesStandardTestSuite) TearDownTest() {

View file

@ -128,7 +128,7 @@ func (m *Module) FavouritesGETHandler(c *gin.Context) {
limit = int(i)
}
resp, errWithCode := m.processor.FavedTimelineGet(c.Request.Context(), authed, maxID, minID, limit)
resp, errWithCode := m.processor.Timeline().FavedTimelineGet(c.Request.Context(), authed, maxID, minID, limit)
if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return

View file

@ -35,6 +35,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -83,6 +84,12 @@ func (suite *FollowRequestStandardTestSuite) SetupTest() {
suite.storage = testrig.NewInMemoryStorage()
suite.state.Storage = suite.storage
testrig.StartTimelines(
&suite.state,
visibility.NewFilter(&suite.state),
testrig.NewTestTypeConverter(suite.db),
)
suite.mediaManager = testrig.NewTestMediaManager(&suite.state)
suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager)
suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil)
@ -90,8 +97,6 @@ func (suite *FollowRequestStandardTestSuite) SetupTest() {
suite.followRequestModule = followrequests.New(suite.processor)
testrig.StandardDBSetup(suite.db, nil)
testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media")
suite.NoError(suite.processor.Start())
}
func (suite *FollowRequestStandardTestSuite) TearDownTest() {

View file

@ -35,6 +35,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -85,6 +86,12 @@ func (suite *InstanceStandardTestSuite) SetupTest() {
suite.storage = testrig.NewInMemoryStorage()
suite.state.Storage = suite.storage
testrig.StartTimelines(
&suite.state,
visibility.NewFilter(&suite.state),
testrig.NewTestTypeConverter(suite.db),
)
suite.mediaManager = testrig.NewTestMediaManager(&suite.state)
suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager)
suite.sentEmails = make(map[string]string)

View file

@ -25,8 +25,15 @@ import (
)
const (
IDKey = "id"
// BasePath is the base path for serving the lists API, minus the 'api' prefix
BasePath = "/v1/lists"
BasePath = "/v1/lists"
BasePathWithID = BasePath + "/:" + IDKey
AccountsPath = BasePathWithID + "/accounts"
MaxIDKey = "max_id"
LimitKey = "limit"
SinceIDKey = "since_id"
MinIDKey = "min_id"
)
type Module struct {
@ -40,5 +47,15 @@ func New(processor *processing.Processor) *Module {
}
func (m *Module) Route(attachHandler func(method string, path string, f ...gin.HandlerFunc) gin.IRoutes) {
// create / get / update / delete lists
attachHandler(http.MethodPost, BasePath, m.ListCreatePOSTHandler)
attachHandler(http.MethodGet, BasePath, m.ListsGETHandler)
attachHandler(http.MethodGet, BasePathWithID, m.ListGETHandler)
attachHandler(http.MethodPut, BasePathWithID, m.ListUpdatePUTHandler)
attachHandler(http.MethodDelete, BasePathWithID, m.ListDELETEHandler)
// get / add / remove list accounts
attachHandler(http.MethodGet, AccountsPath, m.ListAccountsGETHandler)
attachHandler(http.MethodPost, AccountsPath, m.ListAccountsPOSTHandler)
attachHandler(http.MethodDelete, AccountsPath, m.ListAccountsDELETEHandler)
}

View file

@ -0,0 +1,156 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package lists
import (
"errors"
"net/http"
"github.com/gin-gonic/gin"
apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
)
// ListAccountsGETHandler swagger:operation GET /api/v1/list/{id}/accounts listAccounts
//
// Page through accounts in this list.
//
// The returned Link header can be used to generate the previous and next queries when scrolling up or down a timeline.
//
// Example:
//
// ```
// <https://example.org/api/v1/list/01H0W619198FX7J54NF7EH1NG2/accounts?limit=20&max_id=01FC3GSQ8A3MMJ43BPZSGEG29M>; rel="next", <https://example.org/api/v1/list/01H0W619198FX7J54NF7EH1NG2/accounts?limit=20&min_id=01FC3KJW2GYXSDDRA6RWNDM46M>; rel="prev"
// ````
//
// ---
// tags:
// - lists
//
// produces:
// - application/json
//
// parameters:
// -
// name: id
// type: string
// description: ID of the list
// in: path
// required: true
// -
// name: max_id
// type: string
// description: >-
// Return only list entries *OLDER* than the given max ID.
// The account from the list entry with the specified ID will not be included in the response.
// in: query
// required: false
// -
// name: since_id
// type: string
// description: >-
// Return only list entries *NEWER* than the given since ID.
// The account from the list entry with the specified ID will not be included in the response.
// in: query
// -
// name: min_id
// type: string
// description: >-
// Return only list entries *IMMEDIATELY NEWER* than the given min ID.
// The account from the list entry with the specified ID will not be included in the response.
// in: query
// required: false
// -
// name: limit
// type: integer
// description: Number of accounts to return.
// default: 20
// in: query
// required: false
//
// security:
// - OAuth2 Bearer:
// - read:lists
//
// responses:
// '200':
// headers:
// Link:
// type: string
// description: Links to the next and previous queries.
// name: accounts
// description: Array of accounts.
// schema:
// type: array
// items:
// "$ref": "#/definitions/account"
// '400':
// description: bad request
// '401':
// description: unauthorized
// '404':
// description: not found
// '406':
// description: not acceptable
// '500':
// description: internal server error
func (m *Module) ListAccountsGETHandler(c *gin.Context) {
authed, err := oauth.Authed(c, true, true, true, true)
if err != nil {
apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGetV1)
return
}
if _, err := apiutil.NegotiateAccept(c, apiutil.JSONAcceptHeaders...); err != nil {
apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGetV1)
return
}
targetListID := c.Param(IDKey)
if targetListID == "" {
err := errors.New("no list id specified")
apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1)
return
}
limit, errWithCode := apiutil.ParseLimit(c.Query(apiutil.LimitKey), 20)
if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return
}
resp, errWithCode := m.processor.List().GetListAccounts(
c.Request.Context(),
authed.Account,
targetListID,
c.Query(MaxIDKey),
c.Query(SinceIDKey),
c.Query(MinIDKey),
limit,
)
if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return
}
if resp.LinkHeader != "" {
c.Header("Link", resp.LinkHeader)
}
c.JSON(http.StatusOK, resp.Items)
}

View file

@ -0,0 +1,120 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package lists
import (
"errors"
"net/http"
"github.com/gin-gonic/gin"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
)
// ListAccountsPOSTHandler swagger:operation POST /api/v1/list/{id}/accounts addListAccounts
//
// Add one or more accounts to the given list.
//
// ---
// tags:
// - lists
//
// consumes:
// - application/json
// - application/xml
// - application/x-www-form-urlencoded
//
// produces:
// - application/json
//
// parameters:
// -
// name: id
// type: string
// description: ID of the list
// in: path
// required: true
// -
// name: account_ids
// type: array
// items:
// type: string
// description: >-
// Array of accountIDs to modify.
// Each accountID must correspond to an account
// that the requesting account follows.
// in: formData
// required: true
//
// security:
// - OAuth2 Bearer:
// - read:lists
//
// responses:
// '200':
// description: list accounts updated
// '400':
// description: bad request
// '401':
// description: unauthorized
// '404':
// description: not found
// '406':
// description: not acceptable
// '500':
// description: internal server error
func (m *Module) ListAccountsPOSTHandler(c *gin.Context) {
authed, err := oauth.Authed(c, true, true, true, true)
if err != nil {
apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGetV1)
return
}
if _, err := apiutil.NegotiateAccept(c, apiutil.JSONAcceptHeaders...); err != nil {
apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGetV1)
return
}
targetListID := c.Param(IDKey)
if targetListID == "" {
err := errors.New("no list id specified")
apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1)
return
}
form := &apimodel.ListAccountsChangeRequest{}
if err := c.ShouldBind(form); err != nil {
apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1)
return
}
if len(form.AccountIDs) == 0 {
err := errors.New("no account IDs given")
apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1)
return
}
if errWithCode := m.processor.List().AddToList(c.Request.Context(), authed.Account, targetListID, form.AccountIDs); errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return
}
c.JSON(http.StatusOK, gin.H{})
}

View file

@ -0,0 +1,120 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package lists
import (
"errors"
"net/http"
"github.com/gin-gonic/gin"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
)
// ListAccountsDELETEHandler swagger:operation DELETE /api/v1/list/{id}/accounts removeListAccounts
//
// Remove one or more accounts from the given list.
//
// ---
// tags:
// - lists
//
// consumes:
// - application/json
// - application/xml
// - application/x-www-form-urlencoded
//
// produces:
// - application/json
//
// parameters:
// -
// name: id
// type: string
// description: ID of the list
// in: path
// required: true
// -
// name: account_ids
// type: array
// items:
// type: string
// description: >-
// Array of accountIDs to modify.
// Each accountID must correspond to an account
// that the requesting account follows.
// in: formData
// required: true
//
// security:
// - OAuth2 Bearer:
// - read:lists
//
// responses:
// '200':
// description: list accounts updated
// '400':
// description: bad request
// '401':
// description: unauthorized
// '404':
// description: not found
// '406':
// description: not acceptable
// '500':
// description: internal server error
func (m *Module) ListAccountsDELETEHandler(c *gin.Context) {
authed, err := oauth.Authed(c, true, true, true, true)
if err != nil {
apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGetV1)
return
}
if _, err := apiutil.NegotiateAccept(c, apiutil.JSONAcceptHeaders...); err != nil {
apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGetV1)
return
}
targetListID := c.Param(IDKey)
if targetListID == "" {
err := errors.New("no list id specified")
apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1)
return
}
form := &apimodel.ListAccountsChangeRequest{}
if err := c.ShouldBind(form); err != nil {
apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1)
return
}
if len(form.AccountIDs) == 0 {
err := errors.New("no account IDs given")
apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1)
return
}
if errWithCode := m.processor.List().RemoveFromList(c.Request.Context(), authed.Account, targetListID, form.AccountIDs); errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return
}
c.JSON(http.StatusOK, gin.H{})
}

View file

@ -0,0 +1,106 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package lists
import (
"net/http"
"strings"
"github.com/gin-gonic/gin"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/validate"
)
// ListCreatePOSTHandler swagger:operation POST /api/v1/list listCreate
//
// Create a new list.
//
// ---
// tags:
// - lists
//
// consumes:
// - application/json
// - application/xml
// - application/x-www-form-urlencoded
//
// produces:
// - application/json
//
// security:
// - OAuth2 Bearer:
// - write:lists
//
// responses:
// '200':
// description: "The newly created list."
// schema:
// "$ref": "#/definitions/list"
// '400':
// description: bad request
// '401':
// description: unauthorized
// '403':
// description: forbidden
// '404':
// description: not found
// '406':
// description: not acceptable
// '500':
// description: internal server error
func (m *Module) ListCreatePOSTHandler(c *gin.Context) {
authed, err := oauth.Authed(c, true, true, true, true)
if err != nil {
apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGetV1)
return
}
if _, err := apiutil.NegotiateAccept(c, apiutil.JSONAcceptHeaders...); err != nil {
apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGetV1)
return
}
form := &apimodel.ListCreateRequest{}
if err := c.ShouldBind(form); err != nil {
apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1)
return
}
if err := validate.ListTitle(form.Title); err != nil {
apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1)
return
}
repliesPolicy := gtsmodel.RepliesPolicy(strings.ToLower(form.RepliesPolicy))
if err := validate.ListRepliesPolicy(repliesPolicy); err != nil {
apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1)
return
}
apiList, errWithCode := m.processor.List().Create(c.Request.Context(), authed.Account, form.Title, repliesPolicy)
if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return
}
c.JSON(http.StatusOK, apiList)
}

View file

@ -0,0 +1,91 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package lists
import (
"errors"
"net/http"
"github.com/gin-gonic/gin"
apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
)
// ListDELETEHandler swagger:operation DELETE /api/v1/list/{id} listDelete
//
// Delete a single list with the given ID.
//
// ---
// tags:
// - lists
//
// produces:
// - application/json
//
// parameters:
// -
// name: id
// type: string
// description: ID of the list
// in: path
// required: true
//
// security:
// - OAuth2 Bearer:
// - write:lists
//
// responses:
// '200':
// description: list deleted
// '400':
// description: bad request
// '401':
// description: unauthorized
// '404':
// description: not found
// '406':
// description: not acceptable
// '500':
// description: internal server error
func (m *Module) ListDELETEHandler(c *gin.Context) {
authed, err := oauth.Authed(c, true, true, true, true)
if err != nil {
apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGetV1)
return
}
if _, err := apiutil.NegotiateAccept(c, apiutil.JSONAcceptHeaders...); err != nil {
apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGetV1)
return
}
targetListID := c.Param(IDKey)
if targetListID == "" {
err := errors.New("no list id specified")
apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1)
return
}
if errWithCode := m.processor.List().Delete(c.Request.Context(), authed.Account, targetListID); errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return
}
c.JSON(http.StatusOK, gin.H{})
}

View file

@ -0,0 +1,95 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package lists
import (
"errors"
"net/http"
"github.com/gin-gonic/gin"
apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
)
// ListGETHandler swagger:operation GET /api/v1/list/{id} list
//
// Get a single list with the given ID.
//
// ---
// tags:
// - lists
//
// produces:
// - application/json
//
// parameters:
// -
// name: id
// type: string
// description: ID of the list
// in: path
// required: true
//
// security:
// - OAuth2 Bearer:
// - read:lists
//
// responses:
// '200':
// name: list
// description: Requested list.
// schema:
// "$ref": "#/definitions/list"
// '400':
// description: bad request
// '401':
// description: unauthorized
// '404':
// description: not found
// '406':
// description: not acceptable
// '500':
// description: internal server error
func (m *Module) ListGETHandler(c *gin.Context) {
authed, err := oauth.Authed(c, true, true, true, true)
if err != nil {
apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGetV1)
return
}
if _, err := apiutil.NegotiateAccept(c, apiutil.JSONAcceptHeaders...); err != nil {
apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGetV1)
return
}
targetListID := c.Param(IDKey)
if targetListID == "" {
err := errors.New("no list id specified")
apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1)
return
}
resp, errWithCode := m.processor.List().Get(c.Request.Context(), authed.Account, targetListID)
if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return
}
c.JSON(http.StatusOK, resp)
}

View file

@ -26,9 +26,42 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/oauth"
)
// ListsGETHandler returns a list of lists created by/for the authed account
// ListsGETHandler swagger:operation GET /api/v1/lists lists
//
// Get all lists for owned by authorized user.
//
// ---
// tags:
// - lists
//
// produces:
// - application/json
//
// security:
// - OAuth2 Bearer:
// - read:lists
//
// responses:
// '200':
// name: lists
// description: Array of all lists owned by the requesting user.
// schema:
// type: array
// items:
// "$ref": "#/definitions/list"
// '400':
// description: bad request
// '401':
// description: unauthorized
// '404':
// description: not found
// '406':
// description: not acceptable
// '500':
// description: internal server error
func (m *Module) ListsGETHandler(c *gin.Context) {
if _, err := oauth.Authed(c, true, true, true, true); err != nil {
authed, err := oauth.Authed(c, true, true, true, true)
if err != nil {
apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGetV1)
return
}
@ -38,6 +71,11 @@ func (m *Module) ListsGETHandler(c *gin.Context) {
return
}
// todo: implement this; currently it's a no-op
c.JSON(http.StatusOK, []string{})
lists, errWithCode := m.processor.List().GetAll(c.Request.Context(), authed.Account)
if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return
}
c.JSON(http.StatusOK, lists)
}

View file

@ -0,0 +1,152 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package lists
import (
"errors"
"net/http"
"strings"
"github.com/gin-gonic/gin"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/validate"
)
// ListUpdatePUTHandler swagger:operation PUT /api/v1/list listUpdate
//
// Update an existing list.
//
// ---
// tags:
// - lists
//
// consumes:
// - application/json
// - application/xml
// - application/x-www-form-urlencoded
//
// produces:
// - application/json
//
// parameters:
// -
// name: id
// type: string
// description: ID of the list
// in: path
// required: true
// -
// name: title
// type: string
// description: Title of this list.
// in: formData
// example: Cool People
// -
// name: replies_policy
// type: string
// description: |-
// RepliesPolicy for this list.
// followed = Show replies to any followed user
// list = Show replies to members of the list
// none = Show replies to no one
// in: formData
// example: list
//
// security:
// - OAuth2 Bearer:
// - write:lists
//
// responses:
// '200':
// description: "The newly updated list."
// schema:
// "$ref": "#/definitions/list"
// '400':
// description: bad request
// '401':
// description: unauthorized
// '403':
// description: forbidden
// '404':
// description: not found
// '406':
// description: not acceptable
// '500':
// description: internal server error
func (m *Module) ListUpdatePUTHandler(c *gin.Context) {
authed, err := oauth.Authed(c, true, true, true, true)
if err != nil {
apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGetV1)
return
}
if _, err := apiutil.NegotiateAccept(c, apiutil.JSONAcceptHeaders...); err != nil {
apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGetV1)
return
}
targetListID := c.Param(IDKey)
if targetListID == "" {
err := errors.New("no list id specified")
apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1)
return
}
form := &apimodel.ListUpdateRequest{}
if err := c.ShouldBind(form); err != nil {
apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1)
return
}
if form.Title != nil {
if err := validate.ListTitle(*form.Title); err != nil {
apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1)
return
}
}
var repliesPolicy *gtsmodel.RepliesPolicy
if form.RepliesPolicy != nil {
rp := gtsmodel.RepliesPolicy(strings.ToLower(*form.RepliesPolicy))
if err := validate.ListRepliesPolicy(rp); err != nil {
apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1)
return
}
repliesPolicy = &rp
}
if form.Title == nil && repliesPolicy == nil {
err = errors.New("neither title nor replies_policy was set; nothing to update")
apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1)
return
}
apiList, errWithCode := m.processor.List().Update(c.Request.Context(), authed.Account, targetListID, form.Title, repliesPolicy)
if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return
}
c.JSON(http.StatusOK, apiList)
}

View file

@ -44,6 +44,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -90,6 +91,13 @@ func (suite *MediaCreateTestSuite) SetupSuite() {
suite.state.Storage = suite.storage
suite.tc = testrig.NewTestTypeConverter(suite.db)
testrig.StartTimelines(
&suite.state,
visibility.NewFilter(&suite.state),
suite.tc,
)
suite.mediaManager = testrig.NewTestMediaManager(&suite.state)
suite.oauthServer = testrig.NewTestOauthServer(suite.db)
suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager)

View file

@ -42,6 +42,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -87,6 +88,13 @@ func (suite *MediaUpdateTestSuite) SetupSuite() {
suite.state.Storage = suite.storage
suite.tc = testrig.NewTestTypeConverter(suite.db)
testrig.StartTimelines(
&suite.state,
visibility.NewFilter(&suite.state),
suite.tc,
)
suite.mediaManager = testrig.NewTestMediaManager(&suite.state)
suite.oauthServer = testrig.NewTestOauthServer(suite.db)
suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager)

View file

@ -77,7 +77,7 @@ func (m *Module) NotificationGETHandler(c *gin.Context) {
return
}
resp, errWithCode := m.processor.NotificationGet(c.Request.Context(), authed.Account, targetNotifID)
resp, errWithCode := m.processor.Timeline().NotificationGet(c.Request.Context(), authed.Account, targetNotifID)
if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return

View file

@ -69,7 +69,7 @@ func (m *Module) NotificationsClearPOSTHandler(c *gin.Context) {
return
}
errWithCode := m.processor.NotificationsClear(c.Request.Context(), authed)
errWithCode := m.processor.Timeline().NotificationsClear(c.Request.Context(), authed)
if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return

View file

@ -138,7 +138,7 @@ func (m *Module) NotificationsGETHandler(c *gin.Context) {
limit = int(i)
}
resp, errWithCode := m.processor.NotificationsGet(
resp, errWithCode := m.processor.Timeline().NotificationsGet(
c.Request.Context(),
authed,
c.Query(MaxIDKey),

View file

@ -28,6 +28,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -77,6 +78,12 @@ func (suite *ReportsStandardTestSuite) SetupTest() {
suite.storage = testrig.NewInMemoryStorage()
suite.state.Storage = suite.storage
testrig.StartTimelines(
&suite.state,
visibility.NewFilter(&suite.state),
testrig.NewTestTypeConverter(suite.db),
)
suite.mediaManager = testrig.NewTestMediaManager(&suite.state)
suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager)
suite.sentEmails = make(map[string]string)
@ -85,8 +92,6 @@ func (suite *ReportsStandardTestSuite) SetupTest() {
suite.reportsModule = reports.New(suite.processor)
testrig.StandardDBSetup(suite.db, nil)
testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media")
suite.NoError(suite.processor.Start())
}
func (suite *ReportsStandardTestSuite) TearDownTest() {

View file

@ -35,6 +35,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -81,6 +82,12 @@ func (suite *SearchStandardTestSuite) SetupTest() {
suite.storage = testrig.NewInMemoryStorage()
suite.state.Storage = suite.storage
testrig.StartTimelines(
&suite.state,
visibility.NewFilter(&suite.state),
testrig.NewTestTypeConverter(suite.db),
)
suite.mediaManager = testrig.NewTestMediaManager(&suite.state)
suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager)
suite.sentEmails = make(map[string]string)
@ -89,8 +96,6 @@ func (suite *SearchStandardTestSuite) SetupTest() {
suite.searchModule = search.New(suite.processor)
testrig.StandardDBSetup(suite.db, nil)
testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media")
suite.NoError(suite.processor.Start())
}
func (suite *SearchStandardTestSuite) TearDownTest() {

View file

@ -29,6 +29,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -83,6 +84,12 @@ func (suite *StatusStandardTestSuite) SetupTest() {
suite.tc = testrig.NewTestTypeConverter(suite.db)
testrig.StartTimelines(
&suite.state,
visibility.NewFilter(&suite.state),
suite.tc,
)
testrig.StandardDBSetup(suite.db, nil)
testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media")
@ -91,8 +98,6 @@ func (suite *StatusStandardTestSuite) SetupTest() {
suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil)
suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager)
suite.statusModule = statuses.New(suite.processor)
suite.NoError(suite.processor.Start())
}
func (suite *StatusStandardTestSuite) TearDownTest() {

View file

@ -82,6 +82,20 @@ import (
// `direct`: receive updates for direct messages.
// in: query
// required: true
// -
// name: list
// type: string
// description: |-
// ID of the list to subscribe to.
// Only used if stream type is 'list'.
// in: query
// -
// name: tag
// type: string
// description: |-
// Name of the tag to subscribe to.
// Only used if stream type is 'hashtag' or 'hashtag:local'.
// in: query
//
// security:
// - OAuth2 Bearer:
@ -164,8 +178,16 @@ func (m *Module) StreamGETHandler(c *gin.Context) {
}
// Get the initial stream type, if there is one.
// streamType will be an empty string if one wasn't supplied. Open() will deal with this
// By appending other query params to the streamType,
// we can allow for streaming for specific list IDs
// or hashtags.
streamType := c.Query(StreamQueryKey)
if list := c.Query(StreamListKey); list != "" {
streamType += ":" + list
} else if tag := c.Query(StreamTagKey); tag != "" {
streamType += ":" + tag
}
stream, errWithCode := m.processor.Stream().Open(c.Request.Context(), account, streamType)
if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
@ -240,28 +262,41 @@ func (m *Module) StreamGETHandler(c *gin.Context) {
// If the message contains 'stream' and 'type' fields, we can
// update the set of timelines that are subscribed for events.
// everything else is ignored.
action := msg["type"]
streamType := msg["stream"]
// Ignore if the streamType is unknown (or missing), so a bad
// client can't cause extra memory allocations
if !slices.Contains(streampkg.AllStatusTimelines, streamType) {
l.Warnf("Unknown 'stream' field: %v", msg)
updateType, ok := msg["type"]
if !ok {
l.Warn("'type' field not provided")
continue
}
switch action {
updateStream, ok := msg["stream"]
if !ok {
l.Warn("'stream' field not provided")
continue
}
// Ignore if the updateStreamType is unknown (or missing),
// so a bad client can't cause extra memory allocations
if !slices.Contains(streampkg.AllStatusTimelines, updateStream) {
l.Warnf("unknown 'stream' field: %v", msg)
continue
}
updateList, ok := msg["list"]
if ok {
updateStream += ":" + updateList
}
switch updateType {
case "subscribe":
stream.Lock()
stream.Timelines[streamType] = true
stream.StreamTypes[updateStream] = true
stream.Unlock()
case "unsubscribe":
stream.Lock()
delete(stream.Timelines, streamType)
delete(stream.StreamTypes, updateStream)
stream.Unlock()
default:
l.Warnf("Invalid 'type' field: %v", msg)
l.Warnf("invalid 'type' field: %v", msg)
}
}
}()
@ -276,7 +311,7 @@ func (m *Module) StreamGETHandler(c *gin.Context) {
case msg := <-stream.Messages:
l.Tracef("sending message to websocket: %+v", msg)
if err := wsConn.WriteJSON(msg); err != nil {
l.Errorf("error writing json to websocket: %v", err)
l.Debugf("error writing json to websocket: %v", err)
return
}
@ -290,7 +325,7 @@ func (m *Module) StreamGETHandler(c *gin.Context) {
websocket.PingMessage,
[]byte{},
); err != nil {
l.Errorf("error writing ping to websocket: %v", err)
l.Debugf("error writing ping to websocket: %v", err)
return
}
}

View file

@ -27,17 +27,12 @@ import (
)
const (
// BasePath is the path for the streaming api, minus the 'api' prefix
BasePath = "/v1/streaming"
// StreamQueryKey is the query key for the type of stream being requested
StreamQueryKey = "stream"
// AccessTokenQueryKey is the query key for an oauth access token that should be passed in streaming requests.
AccessTokenQueryKey = "access_token"
// AccessTokenHeader is the header for an oauth access token that can be passed in streaming requests instead of AccessTokenQueryKey
//nolint:gosec
AccessTokenHeader = "Sec-Websocket-Protocol"
BasePath = "/v1/streaming" // path for the streaming api, minus the 'api' prefix
StreamQueryKey = "stream" // type of stream being requested
StreamListKey = "list" // id of list being requested
StreamTagKey = "tag" // name of tag being requested
AccessTokenQueryKey = "access_token" // oauth access token
AccessTokenHeader = "Sec-Websocket-Protocol" //nolint:gosec
)
type Module struct {

View file

@ -41,6 +41,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -94,6 +95,13 @@ func (suite *StreamingTestSuite) SetupTest() {
suite.state.Storage = suite.storage
suite.tc = testrig.NewTestTypeConverter(suite.db)
testrig.StartTimelines(
&suite.state,
visibility.NewFilter(&suite.state),
suite.tc,
)
testrig.StandardDBSetup(suite.db, nil)
testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media")
@ -102,7 +110,6 @@ func (suite *StreamingTestSuite) SetupTest() {
suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil)
suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager)
suite.streamingModule = streaming.New(suite.processor, 1, 4096)
suite.NoError(suite.processor.Start())
}
func (suite *StreamingTestSuite) TearDownTest() {

View file

@ -18,9 +18,7 @@
package timelines
import (
"fmt"
"net/http"
"strconv"
"github.com/gin-gonic/gin"
apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
@ -120,49 +118,27 @@ func (m *Module) HomeTimelineGETHandler(c *gin.Context) {
return
}
maxID := ""
maxIDString := c.Query(MaxIDKey)
if maxIDString != "" {
maxID = maxIDString
limit, errWithCode := apiutil.ParseLimit(c.Query(apiutil.LimitKey), 20)
if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return
}
sinceID := ""
sinceIDString := c.Query(SinceIDKey)
if sinceIDString != "" {
sinceID = sinceIDString
local, errWithCode := apiutil.ParseLocal(c.Query(apiutil.LocalKey), false)
if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return
}
minID := ""
minIDString := c.Query(MinIDKey)
if minIDString != "" {
minID = minIDString
}
limit := 20
limitString := c.Query(LimitKey)
if limitString != "" {
i, err := strconv.ParseInt(limitString, 10, 32)
if err != nil {
err := fmt.Errorf("error parsing %s: %s", LimitKey, err)
apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1)
return
}
limit = int(i)
}
local := false
localString := c.Query(LocalKey)
if localString != "" {
i, err := strconv.ParseBool(localString)
if err != nil {
err := fmt.Errorf("error parsing %s: %s", LocalKey, err)
apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1)
return
}
local = i
}
resp, errWithCode := m.processor.HomeTimelineGet(c.Request.Context(), authed, maxID, sinceID, minID, limit, local)
resp, errWithCode := m.processor.Timeline().HomeTimelineGet(
c.Request.Context(),
authed,
c.Query(MaxIDKey),
c.Query(SinceIDKey),
c.Query(MinIDKey),
limit,
local,
)
if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return

View file

@ -0,0 +1,152 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package timelines
import (
"errors"
"net/http"
"github.com/gin-gonic/gin"
apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
)
// ListTimelineGETHandler swagger:operation GET /api/v1/timelines/list/{id} listTimeline
//
// See statuses/posts from the given list timeline.
//
// The statuses will be returned in descending chronological order (newest first), with sequential IDs (bigger = newer).
//
// The returned Link header can be used to generate the previous and next queries when scrolling up or down a timeline.
//
// Example:
//
// ```
// <https://example.org/api/v1/timelines/list/01H0W619198FX7J54NF7EH1NG2?limit=20&max_id=01FC3GSQ8A3MMJ43BPZSGEG29M>; rel="next", <https://example.org/api/v1/timelines/list/01H0W619198FX7J54NF7EH1NG2?limit=20&min_id=01FC3KJW2GYXSDDRA6RWNDM46M>; rel="prev"
// ````
//
// ---
// tags:
// - timelines
//
// produces:
// - application/json
//
// parameters:
// -
// name: id
// type: string
// description: ID of the list
// in: path
// required: true
// -
// name: max_id
// type: string
// description: >-
// Return only statuses *OLDER* than the given max status ID.
// The status with the specified ID will not be included in the response.
// in: query
// required: false
// -
// name: since_id
// type: string
// description: >-
// Return only statuses *NEWER* than the given since status ID.
// The status with the specified ID will not be included in the response.
// in: query
// -
// name: min_id
// type: string
// description: >-
// Return only statuses *NEWER* than the given since status ID.
// The status with the specified ID will not be included in the response.
// in: query
// required: false
// -
// name: limit
// type: integer
// description: Number of statuses to return.
// default: 20
// in: query
// required: false
//
// security:
// - OAuth2 Bearer:
// - read:lists
//
// responses:
// '200':
// name: statuses
// description: Array of statuses.
// schema:
// type: array
// items:
// "$ref": "#/definitions/status"
// headers:
// Link:
// type: string
// description: Links to the next and previous queries.
// '401':
// description: unauthorized
// '400':
// description: bad request
func (m *Module) ListTimelineGETHandler(c *gin.Context) {
authed, err := oauth.Authed(c, true, true, true, true)
if err != nil {
apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGetV1)
return
}
if _, err := apiutil.NegotiateAccept(c, apiutil.JSONAcceptHeaders...); err != nil {
apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGetV1)
return
}
targetListID := c.Param(IDKey)
if targetListID == "" {
err := errors.New("no list id specified")
apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1)
return
}
limit, errWithCode := apiutil.ParseLimit(c.Query(apiutil.LimitKey), 20)
if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return
}
resp, errWithCode := m.processor.Timeline().ListTimelineGet(
c.Request.Context(),
authed,
targetListID,
c.Query(MaxIDKey),
c.Query(SinceIDKey),
c.Query(MinIDKey),
limit,
)
if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return
}
if resp.LinkHeader != "" {
c.Header("Link", resp.LinkHeader)
}
c.JSON(http.StatusOK, resp.Items)
}

View file

@ -18,9 +18,7 @@
package timelines
import (
"fmt"
"net/http"
"strconv"
"github.com/gin-gonic/gin"
apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
@ -131,49 +129,27 @@ func (m *Module) PublicTimelineGETHandler(c *gin.Context) {
return
}
maxID := ""
maxIDString := c.Query(MaxIDKey)
if maxIDString != "" {
maxID = maxIDString
limit, errWithCode := apiutil.ParseLimit(c.Query(apiutil.LimitKey), 20)
if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return
}
sinceID := ""
sinceIDString := c.Query(SinceIDKey)
if sinceIDString != "" {
sinceID = sinceIDString
local, errWithCode := apiutil.ParseLocal(c.Query(apiutil.LocalKey), false)
if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return
}
minID := ""
minIDString := c.Query(MinIDKey)
if minIDString != "" {
minID = minIDString
}
limit := 20
limitString := c.Query(LimitKey)
if limitString != "" {
i, err := strconv.ParseInt(limitString, 10, 32)
if err != nil {
err := fmt.Errorf("error parsing %s: %s", LimitKey, err)
apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1)
return
}
limit = int(i)
}
local := false
localString := c.Query(LocalKey)
if localString != "" {
i, err := strconv.ParseBool(localString)
if err != nil {
err := fmt.Errorf("error parsing %s: %s", LocalKey, err)
apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1)
return
}
local = i
}
resp, errWithCode := m.processor.PublicTimelineGet(c.Request.Context(), authed, maxID, sinceID, minID, limit, local)
resp, errWithCode := m.processor.Timeline().PublicTimelineGet(
c.Request.Context(),
authed,
c.Query(MaxIDKey),
c.Query(SinceIDKey),
c.Query(MinIDKey),
limit,
local,
)
if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return

View file

@ -27,10 +27,12 @@ import (
const (
// BasePath is the base URI path for serving timelines, minus the 'api' prefix.
BasePath = "/v1/timelines"
IDKey = "id"
// HomeTimeline is the path for the home timeline
HomeTimeline = BasePath + "/home"
// PublicTimeline is the path for the public (and public local) timeline
PublicTimeline = BasePath + "/public"
ListTimeline = BasePath + "/list/:" + IDKey
// MaxIDKey is the url query for setting a max status ID to return
MaxIDKey = "max_id"
// SinceIDKey is the url query for returning results newer than the given ID
@ -56,4 +58,5 @@ func New(processor *processing.Processor) *Module {
func (m *Module) Route(attachHandler func(method string, path string, f ...gin.HandlerFunc) gin.IRoutes) {
attachHandler(http.MethodGet, HomeTimeline, m.HomeTimelineGETHandler)
attachHandler(http.MethodGet, PublicTimeline, m.PublicTimelineGETHandler)
attachHandler(http.MethodGet, ListTimeline, m.ListTimelineGETHandler)
}

View file

@ -29,6 +29,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -73,6 +74,13 @@ func (suite *UserStandardTestSuite) SetupTest() {
suite.state.Storage = suite.storage
suite.tc = testrig.NewTestTypeConverter(suite.db)
testrig.StartTimelines(
&suite.state,
visibility.NewFilter(&suite.state),
suite.tc,
)
suite.mediaManager = testrig.NewTestMediaManager(&suite.state)
suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager)
suite.sentEmails = make(map[string]string)
@ -81,8 +89,6 @@ func (suite *UserStandardTestSuite) SetupTest() {
suite.userModule = user.New(suite.processor)
testrig.StandardDBSetup(suite.db, suite.testAccounts)
testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media")
suite.NoError(suite.processor.Start())
}
func (suite *UserStandardTestSuite) TearDownTest() {

View file

@ -33,6 +33,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -81,6 +82,13 @@ func (suite *FileserverTestSuite) SetupSuite() {
suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager)
suite.tc = testrig.NewTestTypeConverter(suite.db)
testrig.StartTimelines(
&suite.state,
visibility.NewFilter(&suite.state),
suite.tc,
)
suite.mediaManager = testrig.NewTestMediaManager(&suite.state)
suite.oauthServer = testrig.NewTestOauthServer(suite.db)
suite.emailSender = testrig.NewEmailSender("../../../web/template/", nil)

View file

@ -17,14 +17,57 @@
package model
// List represents a list of some users that the authenticated user follows.
// List represents a user-created list of accounts that the user follows.
//
// swagger:model list
type List struct {
// The internal database ID of the list.
// The ID of the list.
ID string `json:"id"`
// The user-defined title of the list.
Title string `json:"title"`
// followed = Show replies to any followed user
// RepliesPolicy for this list.
// followed = Show replies to any followed user
// list = Show replies to members of the list
// none = Show replies to no one
RepliesPolicy string `json:"replies_policy"`
}
// ListCreateRequest models list creation parameters.
//
// swagger:parameters listCreate
type ListCreateRequest struct {
// Title of this list.
// example: Cool People
// in: formData
// required: true
Title string `form:"title" json:"title" xml:"title"`
// RepliesPolicy for this list.
// followed = Show replies to any followed user
// list = Show replies to members of the list
// none = Show replies to no one
// example: list
// default: list
// in: formData
RepliesPolicy string `form:"replies_policy" json:"replies_policy" xml:"replies_policy"`
}
// ListUpdateRequest models list update parameters.
//
// swagger:parameters listUpdate
type ListUpdateRequest struct {
// Title of this list.
// example: Cool People
// in: formData
Title *string `form:"title" json:"title" xml:"title"`
// RepliesPolicy for this list.
// followed = Show replies to any followed user
// list = Show replies to members of the list
// none = Show replies to no one
// in: formData
RepliesPolicy *string `form:"replies_policy" json:"replies_policy" xml:"replies_policy"`
}
// swagger:ignore
type ListAccountsChangeRequest struct {
AccountIDs []string `form:"account_ids[]" json:"account_ids" xml:"account_ids"`
}

View file

@ -0,0 +1,58 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package util
import (
"fmt"
"strconv"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
)
const (
LimitKey = "limit"
LocalKey = "local"
)
func ParseLimit(limit string, defaultLimit int) (int, gtserror.WithCode) {
if limit == "" {
return defaultLimit, nil
}
i, err := strconv.Atoi(limit)
if err != nil {
err := fmt.Errorf("error parsing %s: %w", LimitKey, err)
return 0, gtserror.NewErrorBadRequest(err, err.Error())
}
return i, nil
}
func ParseLocal(local string, defaultLocal bool) (bool, gtserror.WithCode) {
if local == "" {
return defaultLocal, nil
}
i, err := strconv.ParseBool(local)
if err != nil {
err := fmt.Errorf("error parsing %s: %w", LocalKey, err)
return false, gtserror.NewErrorBadRequest(err, err.Error())
}
return i, nil
}

View file

@ -30,6 +30,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -79,6 +80,13 @@ func (suite *WebfingerStandardTestSuite) SetupTest() {
suite.db = testrig.NewTestDB(&suite.state)
suite.state.DB = suite.db
suite.tc = testrig.NewTestTypeConverter(suite.db)
testrig.StartTimelines(
&suite.state,
visibility.NewFilter(&suite.state),
suite.tc,
)
suite.storage = testrig.NewInMemoryStorage()
suite.state.Storage = suite.storage
suite.mediaManager = testrig.NewTestMediaManager(&suite.state)
@ -89,8 +97,6 @@ func (suite *WebfingerStandardTestSuite) SetupTest() {
suite.oauthServer = testrig.NewTestOauthServer(suite.db)
testrig.StandardDBSetup(suite.db, suite.testAccounts)
testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media")
suite.NoError(suite.processor.Start())
}
func (suite *WebfingerStandardTestSuite) TearDownTest() {

43
internal/cache/gts.go vendored
View file

@ -35,6 +35,8 @@ type GTSCaches struct {
emojiCategory *result.Cache[*gtsmodel.EmojiCategory]
follow *result.Cache[*gtsmodel.Follow]
followRequest *result.Cache[*gtsmodel.FollowRequest]
list *result.Cache[*gtsmodel.List]
listEntry *result.Cache[*gtsmodel.ListEntry]
media *result.Cache[*gtsmodel.MediaAttachment]
mention *result.Cache[*gtsmodel.Mention]
notification *result.Cache[*gtsmodel.Notification]
@ -57,6 +59,8 @@ func (c *GTSCaches) Init() {
c.initEmojiCategory()
c.initFollow()
c.initFollowRequest()
c.initList()
c.initListEntry()
c.initMedia()
c.initMention()
c.initNotification()
@ -76,6 +80,8 @@ func (c *GTSCaches) Start() {
tryStart(c.emojiCategory, config.GetCacheGTSEmojiCategorySweepFreq())
tryStart(c.follow, config.GetCacheGTSFollowSweepFreq())
tryStart(c.followRequest, config.GetCacheGTSFollowRequestSweepFreq())
tryStart(c.list, config.GetCacheGTSListSweepFreq())
tryStart(c.listEntry, config.GetCacheGTSListEntrySweepFreq())
tryStart(c.media, config.GetCacheGTSMediaSweepFreq())
tryStart(c.mention, config.GetCacheGTSMentionSweepFreq())
tryStart(c.notification, config.GetCacheGTSNotificationSweepFreq())
@ -100,6 +106,8 @@ func (c *GTSCaches) Stop() {
tryStop(c.emojiCategory, config.GetCacheGTSEmojiCategorySweepFreq())
tryStop(c.follow, config.GetCacheGTSFollowSweepFreq())
tryStop(c.followRequest, config.GetCacheGTSFollowRequestSweepFreq())
tryStop(c.list, config.GetCacheGTSListSweepFreq())
tryStop(c.listEntry, config.GetCacheGTSListEntrySweepFreq())
tryStop(c.media, config.GetCacheGTSMediaSweepFreq())
tryStop(c.mention, config.GetCacheGTSNotificationSweepFreq())
tryStop(c.notification, config.GetCacheGTSNotificationSweepFreq())
@ -146,6 +154,16 @@ func (c *GTSCaches) FollowRequest() *result.Cache[*gtsmodel.FollowRequest] {
return c.followRequest
}
// List provides access to the gtsmodel List database cache.
func (c *GTSCaches) List() *result.Cache[*gtsmodel.List] {
return c.list
}
// ListEntry provides access to the gtsmodel ListEntry database cache.
func (c *GTSCaches) ListEntry() *result.Cache[*gtsmodel.ListEntry] {
return c.listEntry
}
// Media provides access to the gtsmodel Media database cache.
func (c *GTSCaches) Media() *result.Cache[*gtsmodel.MediaAttachment] {
return c.media
@ -283,6 +301,30 @@ func (c *GTSCaches) initFollowRequest() {
c.followRequest.SetTTL(config.GetCacheGTSFollowRequestTTL(), true)
}
func (c *GTSCaches) initList() {
c.list = result.New([]result.Lookup{
{Name: "ID"},
}, func(l1 *gtsmodel.List) *gtsmodel.List {
l2 := new(gtsmodel.List)
*l2 = *l1
return l2
}, config.GetCacheGTSListMaxSize())
c.list.SetTTL(config.GetCacheGTSListTTL(), true)
c.list.IgnoreErrors(ignoreErrors)
}
func (c *GTSCaches) initListEntry() {
c.listEntry = result.New([]result.Lookup{
{Name: "ID"},
}, func(l1 *gtsmodel.ListEntry) *gtsmodel.ListEntry {
l2 := new(gtsmodel.ListEntry)
*l2 = *l1
return l2
}, config.GetCacheGTSListEntryMaxSize())
c.list.SetTTL(config.GetCacheGTSListEntryTTL(), true)
c.list.IgnoreErrors(ignoreErrors)
}
func (c *GTSCaches) initMedia() {
c.media = result.New([]result.Lookup{
{Name: "ID"},
@ -359,7 +401,6 @@ func (c *GTSCaches) initStatusFave() {
c.status.IgnoreErrors(ignoreErrors)
}
// initTombstone will initialize the gtsmodel.Tombstone cache.
func (c *GTSCaches) initTombstone() {
c.tombstone = result.New([]result.Lookup{
{Name: "ID"},

View file

@ -199,6 +199,14 @@ type GTSCacheConfiguration struct {
FollowRequestTTL time.Duration `name:"follow-request-ttl"`
FollowRequestSweepFreq time.Duration `name:"follow-request-sweep-freq"`
ListMaxSize int `name:"list-max-size"`
ListTTL time.Duration `name:"list-ttl"`
ListSweepFreq time.Duration `name:"list-sweep-freq"`
ListEntryMaxSize int `name:"list-entry-max-size"`
ListEntryTTL time.Duration `name:"list-entry-ttl"`
ListEntrySweepFreq time.Duration `name:"list-entry-sweep-freq"`
MediaMaxSize int `name:"media-max-size"`
MediaTTL time.Duration `name:"media-ttl"`
MediaSweepFreq time.Duration `name:"media-sweep-freq"`

View file

@ -153,6 +153,14 @@ var Defaults = Configuration{
FollowRequestTTL: time.Minute * 30,
FollowRequestSweepFreq: time.Minute,
ListMaxSize: 2000,
ListTTL: time.Minute * 30,
ListSweepFreq: time.Minute,
ListEntryMaxSize: 2000,
ListEntryTTL: time.Minute * 30,
ListEntrySweepFreq: time.Minute,
MediaMaxSize: 1000,
MediaTTL: time.Minute * 30,
MediaSweepFreq: time.Minute,

View file

@ -2778,6 +2778,156 @@ func GetCacheGTSFollowRequestSweepFreq() time.Duration {
// SetCacheGTSFollowRequestSweepFreq safely sets the value for global configuration 'Cache.GTS.FollowRequestSweepFreq' field
func SetCacheGTSFollowRequestSweepFreq(v time.Duration) { global.SetCacheGTSFollowRequestSweepFreq(v) }
// GetCacheGTSListMaxSize safely fetches the Configuration value for state's 'Cache.GTS.ListMaxSize' field
func (st *ConfigState) GetCacheGTSListMaxSize() (v int) {
st.mutex.Lock()
v = st.config.Cache.GTS.ListMaxSize
st.mutex.Unlock()
return
}
// SetCacheGTSListMaxSize safely sets the Configuration value for state's 'Cache.GTS.ListMaxSize' field
func (st *ConfigState) SetCacheGTSListMaxSize(v int) {
st.mutex.Lock()
defer st.mutex.Unlock()
st.config.Cache.GTS.ListMaxSize = v
st.reloadToViper()
}
// CacheGTSListMaxSizeFlag returns the flag name for the 'Cache.GTS.ListMaxSize' field
func CacheGTSListMaxSizeFlag() string { return "cache-gts-list-max-size" }
// GetCacheGTSListMaxSize safely fetches the value for global configuration 'Cache.GTS.ListMaxSize' field
func GetCacheGTSListMaxSize() int { return global.GetCacheGTSListMaxSize() }
// SetCacheGTSListMaxSize safely sets the value for global configuration 'Cache.GTS.ListMaxSize' field
func SetCacheGTSListMaxSize(v int) { global.SetCacheGTSListMaxSize(v) }
// GetCacheGTSListTTL safely fetches the Configuration value for state's 'Cache.GTS.ListTTL' field
func (st *ConfigState) GetCacheGTSListTTL() (v time.Duration) {
st.mutex.Lock()
v = st.config.Cache.GTS.ListTTL
st.mutex.Unlock()
return
}
// SetCacheGTSListTTL safely sets the Configuration value for state's 'Cache.GTS.ListTTL' field
func (st *ConfigState) SetCacheGTSListTTL(v time.Duration) {
st.mutex.Lock()
defer st.mutex.Unlock()
st.config.Cache.GTS.ListTTL = v
st.reloadToViper()
}
// CacheGTSListTTLFlag returns the flag name for the 'Cache.GTS.ListTTL' field
func CacheGTSListTTLFlag() string { return "cache-gts-list-ttl" }
// GetCacheGTSListTTL safely fetches the value for global configuration 'Cache.GTS.ListTTL' field
func GetCacheGTSListTTL() time.Duration { return global.GetCacheGTSListTTL() }
// SetCacheGTSListTTL safely sets the value for global configuration 'Cache.GTS.ListTTL' field
func SetCacheGTSListTTL(v time.Duration) { global.SetCacheGTSListTTL(v) }
// GetCacheGTSListSweepFreq safely fetches the Configuration value for state's 'Cache.GTS.ListSweepFreq' field
func (st *ConfigState) GetCacheGTSListSweepFreq() (v time.Duration) {
st.mutex.Lock()
v = st.config.Cache.GTS.ListSweepFreq
st.mutex.Unlock()
return
}
// SetCacheGTSListSweepFreq safely sets the Configuration value for state's 'Cache.GTS.ListSweepFreq' field
func (st *ConfigState) SetCacheGTSListSweepFreq(v time.Duration) {
st.mutex.Lock()
defer st.mutex.Unlock()
st.config.Cache.GTS.ListSweepFreq = v
st.reloadToViper()
}
// CacheGTSListSweepFreqFlag returns the flag name for the 'Cache.GTS.ListSweepFreq' field
func CacheGTSListSweepFreqFlag() string { return "cache-gts-list-sweep-freq" }
// GetCacheGTSListSweepFreq safely fetches the value for global configuration 'Cache.GTS.ListSweepFreq' field
func GetCacheGTSListSweepFreq() time.Duration { return global.GetCacheGTSListSweepFreq() }
// SetCacheGTSListSweepFreq safely sets the value for global configuration 'Cache.GTS.ListSweepFreq' field
func SetCacheGTSListSweepFreq(v time.Duration) { global.SetCacheGTSListSweepFreq(v) }
// GetCacheGTSListEntryMaxSize safely fetches the Configuration value for state's 'Cache.GTS.ListEntryMaxSize' field
func (st *ConfigState) GetCacheGTSListEntryMaxSize() (v int) {
st.mutex.Lock()
v = st.config.Cache.GTS.ListEntryMaxSize
st.mutex.Unlock()
return
}
// SetCacheGTSListEntryMaxSize safely sets the Configuration value for state's 'Cache.GTS.ListEntryMaxSize' field
func (st *ConfigState) SetCacheGTSListEntryMaxSize(v int) {
st.mutex.Lock()
defer st.mutex.Unlock()
st.config.Cache.GTS.ListEntryMaxSize = v
st.reloadToViper()
}
// CacheGTSListEntryMaxSizeFlag returns the flag name for the 'Cache.GTS.ListEntryMaxSize' field
func CacheGTSListEntryMaxSizeFlag() string { return "cache-gts-list-entry-max-size" }
// GetCacheGTSListEntryMaxSize safely fetches the value for global configuration 'Cache.GTS.ListEntryMaxSize' field
func GetCacheGTSListEntryMaxSize() int { return global.GetCacheGTSListEntryMaxSize() }
// SetCacheGTSListEntryMaxSize safely sets the value for global configuration 'Cache.GTS.ListEntryMaxSize' field
func SetCacheGTSListEntryMaxSize(v int) { global.SetCacheGTSListEntryMaxSize(v) }
// GetCacheGTSListEntryTTL safely fetches the Configuration value for state's 'Cache.GTS.ListEntryTTL' field
func (st *ConfigState) GetCacheGTSListEntryTTL() (v time.Duration) {
st.mutex.Lock()
v = st.config.Cache.GTS.ListEntryTTL
st.mutex.Unlock()
return
}
// SetCacheGTSListEntryTTL safely sets the Configuration value for state's 'Cache.GTS.ListEntryTTL' field
func (st *ConfigState) SetCacheGTSListEntryTTL(v time.Duration) {
st.mutex.Lock()
defer st.mutex.Unlock()
st.config.Cache.GTS.ListEntryTTL = v
st.reloadToViper()
}
// CacheGTSListEntryTTLFlag returns the flag name for the 'Cache.GTS.ListEntryTTL' field
func CacheGTSListEntryTTLFlag() string { return "cache-gts-list-entry-ttl" }
// GetCacheGTSListEntryTTL safely fetches the value for global configuration 'Cache.GTS.ListEntryTTL' field
func GetCacheGTSListEntryTTL() time.Duration { return global.GetCacheGTSListEntryTTL() }
// SetCacheGTSListEntryTTL safely sets the value for global configuration 'Cache.GTS.ListEntryTTL' field
func SetCacheGTSListEntryTTL(v time.Duration) { global.SetCacheGTSListEntryTTL(v) }
// GetCacheGTSListEntrySweepFreq safely fetches the Configuration value for state's 'Cache.GTS.ListEntrySweepFreq' field
func (st *ConfigState) GetCacheGTSListEntrySweepFreq() (v time.Duration) {
st.mutex.Lock()
v = st.config.Cache.GTS.ListEntrySweepFreq
st.mutex.Unlock()
return
}
// SetCacheGTSListEntrySweepFreq safely sets the Configuration value for state's 'Cache.GTS.ListEntrySweepFreq' field
func (st *ConfigState) SetCacheGTSListEntrySweepFreq(v time.Duration) {
st.mutex.Lock()
defer st.mutex.Unlock()
st.config.Cache.GTS.ListEntrySweepFreq = v
st.reloadToViper()
}
// CacheGTSListEntrySweepFreqFlag returns the flag name for the 'Cache.GTS.ListEntrySweepFreq' field
func CacheGTSListEntrySweepFreqFlag() string { return "cache-gts-list-entry-sweep-freq" }
// GetCacheGTSListEntrySweepFreq safely fetches the value for global configuration 'Cache.GTS.ListEntrySweepFreq' field
func GetCacheGTSListEntrySweepFreq() time.Duration { return global.GetCacheGTSListEntrySweepFreq() }
// SetCacheGTSListEntrySweepFreq safely sets the value for global configuration 'Cache.GTS.ListEntrySweepFreq' field
func SetCacheGTSListEntrySweepFreq(v time.Duration) { global.SetCacheGTSListEntrySweepFreq(v) }
// GetCacheGTSMediaMaxSize safely fetches the Configuration value for state's 'Cache.GTS.MediaMaxSize' field
func (st *ConfigState) GetCacheGTSMediaMaxSize() (v int) {
st.mutex.Lock()

View file

@ -65,6 +65,7 @@ type DBService struct {
db.Domain
db.Emoji
db.Instance
db.List
db.Media
db.Mention
db.Notification
@ -179,6 +180,10 @@ func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) {
Instance: &instanceDB{
conn: conn,
},
List: &listDB{
conn: conn,
state: state,
},
Media: &mediaDB{
conn: conn,
state: state,

View file

@ -22,6 +22,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -46,6 +47,8 @@ type BunDBStandardTestSuite struct {
testReports map[string]*gtsmodel.Report
testBookmarks map[string]*gtsmodel.StatusBookmark
testFaves map[string]*gtsmodel.StatusFave
testLists map[string]*gtsmodel.List
testListEntries map[string]*gtsmodel.ListEntry
}
func (suite *BunDBStandardTestSuite) SetupSuite() {
@ -63,6 +66,8 @@ func (suite *BunDBStandardTestSuite) SetupSuite() {
suite.testReports = testrig.NewTestReports()
suite.testBookmarks = testrig.NewTestBookmarks()
suite.testFaves = testrig.NewTestFaves()
suite.testLists = testrig.NewTestLists()
suite.testListEntries = testrig.NewTestListEntries()
}
func (suite *BunDBStandardTestSuite) SetupTest() {
@ -70,6 +75,7 @@ func (suite *BunDBStandardTestSuite) SetupTest() {
testrig.InitTestLog()
suite.state.Caches.Init()
suite.db = testrig.NewTestDB(&suite.state)
testrig.StartTimelines(&suite.state, visibility.NewFilter(&suite.state), testrig.NewTestTypeConverter(suite.db))
testrig.StandardDBSetup(suite.db, suite.testAccounts)
}

467
internal/db/bundb/list.go Normal file
View file

@ -0,0 +1,467 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package bundb
import (
"context"
"errors"
"fmt"
"time"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/uptrace/bun"
)
type listDB struct {
conn *DBConn
state *state.State
}
/*
LIST FUNCTIONS
*/
func (l *listDB) getList(ctx context.Context, lookup string, dbQuery func(*gtsmodel.List) error, keyParts ...any) (*gtsmodel.List, error) {
list, err := l.state.Caches.GTS.List().Load(lookup, func() (*gtsmodel.List, error) {
var list gtsmodel.List
// Not cached! Perform database query.
if err := dbQuery(&list); err != nil {
return nil, l.conn.ProcessError(err)
}
return &list, nil
}, keyParts...)
if err != nil {
return nil, err // already processed
}
if gtscontext.Barebones(ctx) {
// Only a barebones model was requested.
return list, nil
}
if err := l.state.DB.PopulateList(ctx, list); err != nil {
return nil, err
}
return list, nil
}
func (l *listDB) GetListByID(ctx context.Context, id string) (*gtsmodel.List, error) {
return l.getList(
ctx,
"ID",
func(list *gtsmodel.List) error {
return l.conn.NewSelect().
Model(list).
Where("? = ?", bun.Ident("list.id"), id).
Scan(ctx)
},
id,
)
}
func (l *listDB) GetListsForAccountID(ctx context.Context, accountID string) ([]*gtsmodel.List, error) {
// Fetch IDs of all lists owned by this account.
var listIDs []string
if err := l.conn.
NewSelect().
TableExpr("? AS ?", bun.Ident("lists"), bun.Ident("list")).
Column("list.id").
Where("? = ?", bun.Ident("list.account_id"), accountID).
Order("list.id DESC").
Scan(ctx, &listIDs); err != nil {
return nil, l.conn.ProcessError(err)
}
if len(listIDs) == 0 {
return nil, nil
}
// Select each list using its ID to ensure cache used.
lists := make([]*gtsmodel.List, 0, len(listIDs))
for _, id := range listIDs {
list, err := l.state.DB.GetListByID(ctx, id)
if err != nil {
log.Errorf(ctx, "error fetching list %q: %v", id, err)
continue
}
// Append list.
lists = append(lists, list)
}
return lists, nil
}
func (l *listDB) PopulateList(ctx context.Context, list *gtsmodel.List) error {
var (
err error
errs = make(gtserror.MultiError, 0, 2)
)
if list.Account == nil {
// List account is not set, fetch from the database.
list.Account, err = l.state.DB.GetAccountByID(
gtscontext.SetBarebones(ctx),
list.AccountID,
)
if err != nil {
errs.Append(fmt.Errorf("error populating list account: %w", err))
}
}
if list.ListEntries == nil {
// List entries are not set, fetch from the database.
list.ListEntries, err = l.state.DB.GetListEntries(
gtscontext.SetBarebones(ctx),
list.ID,
"", "", "", 0,
)
if err != nil {
errs.Append(fmt.Errorf("error populating list entries: %w", err))
}
}
return errs.Combine()
}
func (l *listDB) PutList(ctx context.Context, list *gtsmodel.List) error {
return l.state.Caches.GTS.List().Store(list, func() error {
_, err := l.conn.NewInsert().Model(list).Exec(ctx)
return l.conn.ProcessError(err)
})
}
func (l *listDB) UpdateList(ctx context.Context, list *gtsmodel.List, columns ...string) error {
list.UpdatedAt = time.Now()
if len(columns) > 0 {
// If we're updating by column, ensure "updated_at" is included.
columns = append(columns, "updated_at")
}
return l.state.Caches.GTS.List().Store(list, func() error {
if _, err := l.conn.NewUpdate().
Model(list).
Where("? = ?", bun.Ident("list.id"), list.ID).
Column(columns...).
Exec(ctx); err != nil {
return l.conn.ProcessError(err)
}
return nil
})
}
func (l *listDB) DeleteListByID(ctx context.Context, id string) error {
defer l.state.Caches.GTS.List().Invalidate("ID", id)
// Select all entries that belong to this list.
listEntries, err := l.state.DB.GetListEntries(ctx, id, "", "", "", 0)
if err != nil {
return fmt.Errorf("error selecting entries from list %q: %w", id, err)
}
// Delete each list entry. This will
// invalidate the list timeline too.
for _, listEntry := range listEntries {
err := l.state.DB.DeleteListEntry(ctx, listEntry.ID)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
return err
}
}
// Finally delete list itself from DB.
_, err = l.conn.NewDelete().
Table("lists").
Where("? = ?", bun.Ident("id"), id).
Exec(ctx)
return l.conn.ProcessError(err)
}
/*
LIST ENTRY functions
*/
func (l *listDB) getListEntry(ctx context.Context, lookup string, dbQuery func(*gtsmodel.ListEntry) error, keyParts ...any) (*gtsmodel.ListEntry, error) {
listEntry, err := l.state.Caches.GTS.ListEntry().Load(lookup, func() (*gtsmodel.ListEntry, error) {
var listEntry gtsmodel.ListEntry
// Not cached! Perform database query.
if err := dbQuery(&listEntry); err != nil {
return nil, l.conn.ProcessError(err)
}
return &listEntry, nil
}, keyParts...)
if err != nil {
return nil, err // already processed
}
if gtscontext.Barebones(ctx) {
// Only a barebones model was requested.
return listEntry, nil
}
// Further populate the list entry fields where applicable.
if err := l.state.DB.PopulateListEntry(ctx, listEntry); err != nil {
return nil, err
}
return listEntry, nil
}
func (l *listDB) GetListEntryByID(ctx context.Context, id string) (*gtsmodel.ListEntry, error) {
return l.getListEntry(
ctx,
"ID",
func(listEntry *gtsmodel.ListEntry) error {
return l.conn.NewSelect().
Model(listEntry).
Where("? = ?", bun.Ident("list_entry.id"), id).
Scan(ctx)
},
id,
)
}
func (l *listDB) GetListEntries(ctx context.Context,
listID string,
maxID string,
sinceID string,
minID string,
limit int,
) ([]*gtsmodel.ListEntry, error) {
// Ensure reasonable
if limit < 0 {
limit = 0
}
// Make educated guess for slice size
var (
entryIDs = make([]string, 0, limit)
frontToBack = true
)
q := l.conn.
NewSelect().
TableExpr("? AS ?", bun.Ident("list_entries"), bun.Ident("entry")).
// Select only IDs from table
Column("entry.id").
// Select only entries belonging to listID.
Where("? = ?", bun.Ident("entry.list_id"), listID)
if maxID != "" {
// return only entries LOWER (ie., older) than maxID
q = q.Where("? < ?", bun.Ident("entry.id"), maxID)
}
if sinceID != "" {
// return only entries HIGHER (ie., newer) than sinceID
q = q.Where("? > ?", bun.Ident("entry.id"), sinceID)
}
if minID != "" {
// return only entries HIGHER (ie., newer) than minID
q = q.Where("? > ?", bun.Ident("entry.id"), minID)
// page up
frontToBack = false
}
if limit > 0 {
// limit amount of entries returned
q = q.Limit(limit)
}
if frontToBack {
// Page down.
q = q.Order("entry.id DESC")
} else {
// Page up.
q = q.Order("entry.id ASC")
}
if err := q.Scan(ctx, &entryIDs); err != nil {
return nil, l.conn.ProcessError(err)
}
if len(entryIDs) == 0 {
return nil, nil
}
// If we're paging up, we still want entries
// to be sorted by ID desc, so reverse ids slice.
// https://zchee.github.io/golang-wiki/SliceTricks/#reversing
if !frontToBack {
for l, r := 0, len(entryIDs)-1; l < r; l, r = l+1, r-1 {
entryIDs[l], entryIDs[r] = entryIDs[r], entryIDs[l]
}
}
// Select each list entry using its ID to ensure cache used.
listEntries := make([]*gtsmodel.ListEntry, 0, len(entryIDs))
for _, id := range entryIDs {
listEntry, err := l.state.DB.GetListEntryByID(ctx, id)
if err != nil {
log.Errorf(ctx, "error fetching list entry %q: %v", id, err)
continue
}
// Append list entries.
listEntries = append(listEntries, listEntry)
}
return listEntries, nil
}
func (l *listDB) GetListEntriesForFollowID(ctx context.Context, followID string) ([]*gtsmodel.ListEntry, error) {
entryIDs := []string{}
if err := l.conn.
NewSelect().
TableExpr("? AS ?", bun.Ident("list_entries"), bun.Ident("entry")).
// Select only IDs from table
Column("entry.id").
// Select only entries belonging with given followID.
Where("? = ?", bun.Ident("entry.follow_id"), followID).
Scan(ctx, &entryIDs); err != nil {
return nil, l.conn.ProcessError(err)
}
if len(entryIDs) == 0 {
return nil, nil
}
// Select each list entry using its ID to ensure cache used.
listEntries := make([]*gtsmodel.ListEntry, 0, len(entryIDs))
for _, id := range entryIDs {
listEntry, err := l.state.DB.GetListEntryByID(ctx, id)
if err != nil {
log.Errorf(ctx, "error fetching list entry %q: %v", id, err)
continue
}
// Append list entries.
listEntries = append(listEntries, listEntry)
}
return listEntries, nil
}
func (l *listDB) PopulateListEntry(ctx context.Context, listEntry *gtsmodel.ListEntry) error {
var err error
if listEntry.Follow == nil {
// ListEntry follow is not set, fetch from the database.
listEntry.Follow, err = l.state.DB.GetFollowByID(
gtscontext.SetBarebones(ctx),
listEntry.FollowID,
)
if err != nil {
return fmt.Errorf("error populating listEntry follow: %w", err)
}
}
return nil
}
func (l *listDB) PutListEntries(ctx context.Context, listEntries []*gtsmodel.ListEntry) error {
return l.conn.RunInTx(ctx, func(tx bun.Tx) error {
for _, listEntry := range listEntries {
if _, err := tx.
NewInsert().
Model(listEntry).
Exec(ctx); err != nil {
return err
}
// Invalidate the timeline for the list this entry belongs to.
if err := l.state.Timelines.List.RemoveTimeline(ctx, listEntry.ListID); err != nil {
log.Errorf(ctx, "PutListEntries: error invalidating list timeline: %q", err)
}
}
return nil
})
}
func (l *listDB) DeleteListEntry(ctx context.Context, id string) error {
defer l.state.Caches.GTS.ListEntry().Invalidate("ID", id)
// Load list entry into cache before attempting a delete,
// as we need the followID from it in order to trigger
// timeline invalidation.
listEntry, err := l.GetListEntryByID(
// Don't populate the entry;
// we only want the list ID.
gtscontext.SetBarebones(ctx),
id,
)
if err != nil {
if errors.Is(err, db.ErrNoEntries) {
// Already gone.
return nil
}
return err
}
defer func() {
// Invalidate the timeline for the list this entry belongs to.
if err := l.state.Timelines.List.RemoveTimeline(ctx, listEntry.ListID); err != nil {
log.Errorf(ctx, "DeleteListEntry: error invalidating list timeline: %q", err)
}
}()
if _, err := l.conn.NewDelete().
Table("list_entries").
Where("? = ?", bun.Ident("id"), listEntry.ID).
Exec(ctx); err != nil {
return l.conn.ProcessError(err)
}
return nil
}
func (l *listDB) DeleteListEntriesForFollowID(ctx context.Context, followID string) error {
// Fetch IDs of all entries that pertain to this follow.
var listEntryIDs []string
if err := l.conn.
NewSelect().
TableExpr("? AS ?", bun.Ident("list_entries"), bun.Ident("list_entry")).
Column("list_entry.id").
Where("? = ?", bun.Ident("list_entry.follow_id"), followID).
Order("list_entry.id DESC").
Scan(ctx, &listEntryIDs); err != nil {
return l.conn.ProcessError(err)
}
for _, id := range listEntryIDs {
if err := l.DeleteListEntry(ctx, id); err != nil {
return err
}
}
return nil
}

View file

@ -0,0 +1,315 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package bundb_test
import (
"context"
"testing"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"golang.org/x/exp/slices"
)
type ListTestSuite struct {
BunDBStandardTestSuite
}
func (suite *ListTestSuite) testStructs() (*gtsmodel.List, *gtsmodel.Account) {
testList := &gtsmodel.List{}
*testList = *suite.testLists["local_account_1_list_1"]
// Populate entries on this list as we'd expect them back from the db.
entries := make([]*gtsmodel.ListEntry, 0, len(suite.testListEntries))
for _, entry := range suite.testListEntries {
entries = append(entries, entry)
}
// Sort by ID descending (again, as we'd expect from the db).
slices.SortFunc(entries, func(a, b *gtsmodel.ListEntry) bool {
return b.ID < a.ID
})
testList.ListEntries = entries
testAccount := &gtsmodel.Account{}
*testAccount = *suite.testAccounts["local_account_1"]
return testList, testAccount
}
func (suite *ListTestSuite) checkList(expected *gtsmodel.List, actual *gtsmodel.List) {
suite.Equal(expected.ID, actual.ID)
suite.Equal(expected.Title, actual.Title)
suite.Equal(expected.AccountID, actual.AccountID)
suite.Equal(expected.RepliesPolicy, actual.RepliesPolicy)
suite.NotNil(actual.Account)
}
func (suite *ListTestSuite) checkListEntry(expected *gtsmodel.ListEntry, actual *gtsmodel.ListEntry) {
suite.Equal(expected.ID, actual.ID)
suite.Equal(expected.ListID, actual.ListID)
suite.Equal(expected.FollowID, actual.FollowID)
}
func (suite *ListTestSuite) checkListEntries(expected []*gtsmodel.ListEntry, actual []*gtsmodel.ListEntry) {
var (
lExpected = len(expected)
lActual = len(actual)
)
if lExpected != lActual {
suite.FailNow("", "expected %d list entries, got %d", lExpected, lActual)
}
var topID string
for i, expectedEntry := range expected {
actualEntry := actual[i]
// Ensure ID descending.
if topID == "" {
topID = actualEntry.ID
} else {
suite.Less(actualEntry.ID, topID)
}
suite.checkListEntry(expectedEntry, actualEntry)
}
}
func (suite *ListTestSuite) TestGetListByID() {
testList, _ := suite.testStructs()
dbList, err := suite.db.GetListByID(context.Background(), testList.ID)
if err != nil {
suite.FailNow(err.Error())
}
suite.checkList(testList, dbList)
suite.checkListEntries(testList.ListEntries, dbList.ListEntries)
}
func (suite *ListTestSuite) TestGetListsForAccountID() {
testList, testAccount := suite.testStructs()
dbLists, err := suite.db.GetListsForAccountID(context.Background(), testAccount.ID)
if err != nil {
suite.FailNow(err.Error())
}
if l := len(dbLists); l != 1 {
suite.FailNow("", "expected %d lists, got %d", 1, l)
}
suite.checkList(testList, dbLists[0])
}
func (suite *ListTestSuite) TestGetListEntries() {
testList, _ := suite.testStructs()
dbListEntries, err := suite.db.GetListEntries(context.Background(), testList.ID, "", "", "", 0)
if err != nil {
suite.FailNow(err.Error())
}
suite.checkListEntries(testList.ListEntries, dbListEntries)
}
func (suite *ListTestSuite) TestPutList() {
ctx := context.Background()
_, testAccount := suite.testStructs()
testList := &gtsmodel.List{
ID: "01H0J2PMYM54618VCV8Y8QYAT4",
Title: "Test List!",
AccountID: testAccount.ID,
}
if err := suite.db.PutList(ctx, testList); err != nil {
suite.FailNow(err.Error())
}
dbList, err := suite.db.GetListByID(ctx, testList.ID)
if err != nil {
suite.FailNow(err.Error())
}
// Bodge testlist as though default had been set.
testList.RepliesPolicy = gtsmodel.RepliesPolicyFollowed
suite.checkList(testList, dbList)
}
func (suite *ListTestSuite) TestUpdateList() {
ctx := context.Background()
testList, _ := suite.testStructs()
// Get List in the cache first.
dbList, err := suite.db.GetListByID(ctx, testList.ID)
if err != nil {
suite.FailNow(err.Error())
}
// Now do the update.
testList.Title = "New Title!"
if err := suite.db.UpdateList(ctx, testList, "title"); err != nil {
suite.FailNow(err.Error())
}
// Cache should be invalidated
// + we should have updated list.
dbList, err = suite.db.GetListByID(ctx, testList.ID)
if err != nil {
suite.FailNow(err.Error())
}
suite.checkList(testList, dbList)
}
func (suite *ListTestSuite) TestDeleteList() {
ctx := context.Background()
testList, _ := suite.testStructs()
// Get List in the cache first.
if _, err := suite.db.GetListByID(ctx, testList.ID); err != nil {
suite.FailNow(err.Error())
}
// Now do the delete.
if err := suite.db.DeleteListByID(ctx, testList.ID); err != nil {
suite.FailNow(err.Error())
}
// Cache should be invalidated
// + we should have no list.
_, err := suite.db.GetListByID(ctx, testList.ID)
suite.ErrorIs(err, db.ErrNoEntries)
// All entries belonging to this
// list should now be deleted.
listEntries, err := suite.db.GetListEntries(ctx, testList.ID, "", "", "", 0)
if err != nil {
suite.FailNow(err.Error())
}
suite.Empty(listEntries)
}
func (suite *ListTestSuite) TestPutListEntries() {
ctx := context.Background()
testList, _ := suite.testStructs()
listEntries := []*gtsmodel.ListEntry{
{
ID: "01H0MKMQY69HWDSDR2SWGA17R4",
ListID: testList.ID,
FollowID: "01H0MKNFRFZS8R9WV6DBX31Y03", // random id, doesn't exist
},
{
ID: "01H0MKPGQF0E7QAVW5BKTHZ630",
ListID: testList.ID,
FollowID: "01H0MKP6RR8VEHN3GVWFBP2H30", // random id, doesn't exist
},
{
ID: "01H0MKPPP2DT68FRBMR1FJM32T",
ListID: testList.ID,
FollowID: "01H0MKQ0KA29C6NFJ27GTZD16J", // random id, doesn't exist
},
}
if err := suite.db.PutListEntries(ctx, listEntries); err != nil {
suite.FailNow(err.Error())
}
// Add these entries to the test list, sort it again
// to reflect what we'd expect to get from the db.
testList.ListEntries = append(testList.ListEntries, listEntries...)
slices.SortFunc(testList.ListEntries, func(a, b *gtsmodel.ListEntry) bool {
return b.ID < a.ID
})
// Now get all list entries from the db.
// Use barebones for this because the ones
// we just added will fail if we try to get
// the nonexistent follows.
dbListEntries, err := suite.db.GetListEntries(
gtscontext.SetBarebones(ctx),
testList.ID,
"", "", "", 0)
if err != nil {
suite.FailNow(err.Error())
}
suite.checkListEntries(testList.ListEntries, dbListEntries)
}
func (suite *ListTestSuite) TestDeleteListEntry() {
ctx := context.Background()
testList, _ := suite.testStructs()
// Get List in the cache first.
if _, err := suite.db.GetListByID(ctx, testList.ID); err != nil {
suite.FailNow(err.Error())
}
// Delete the first entry.
if err := suite.db.DeleteListEntry(ctx, testList.ListEntries[0].ID); err != nil {
suite.FailNow(err.Error())
}
// Get list from the db again.
dbList, err := suite.db.GetListByID(ctx, testList.ID)
if err != nil {
suite.FailNow(err.Error())
}
// Bodge the testlist as though
// we'd removed the first entry.
testList.ListEntries = testList.ListEntries[1:]
suite.checkList(testList, dbList)
}
func (suite *ListTestSuite) TestDeleteListEntriesForFollowID() {
ctx := context.Background()
testList, _ := suite.testStructs()
// Get List in the cache first.
if _, err := suite.db.GetListByID(ctx, testList.ID); err != nil {
suite.FailNow(err.Error())
}
// Delete the first entry.
if err := suite.db.DeleteListEntriesForFollowID(ctx, testList.ListEntries[0].FollowID); err != nil {
suite.FailNow(err.Error())
}
// Get list from the db again.
dbList, err := suite.db.GetListByID(ctx, testList.ID)
if err != nil {
suite.FailNow(err.Error())
}
// Bodge the testlist as though
// we'd removed the first entry.
testList.ListEntries = testList.ListEntries[1:]
suite.checkList(testList, dbList)
}
func TestListTestSuite(t *testing.T) {
suite.Run(t, new(ListTestSuite))
}

View file

@ -0,0 +1,92 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package migrations
import (
"context"
gtsmodel "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/uptrace/bun"
)
func init() {
up := func(ctx context.Context, db *bun.DB) error {
return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
// List table.
if _, err := tx.
NewCreateTable().
Model(&gtsmodel.List{}).
IfNotExists().
Exec(ctx); err != nil {
return err
}
// Add indexes to the List table.
for index, columns := range map[string][]string{
"lists_id_idx": {"id"},
"lists_account_id_idx": {"account_id"},
} {
if _, err := tx.
NewCreateIndex().
Table("lists").
Index(index).
Column(columns...).
Exec(ctx); err != nil {
return err
}
}
// List entry table.
if _, err := tx.
NewCreateTable().
Model(&gtsmodel.ListEntry{}).
IfNotExists().
Exec(ctx); err != nil {
return err
}
// Add indexes to the List entry table.
for index, columns := range map[string][]string{
"list_entries_id_idx": {"id"},
"list_entries_list_id_idx": {"list_id"},
"list_entries_follow_id_idx": {"follow_id"},
} {
if _, err := tx.
NewCreateIndex().
Table("list_entries").
Index(index).
Column(columns...).
Exec(ctx); err != nil {
return err
}
}
return nil
})
}
down := func(ctx context.Context, db *bun.DB) error {
return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
return nil
})
}
if err := Migrations.Register(up, down); err != nil {
panic(err)
}
}

View file

@ -25,6 +25,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/uptrace/bun"
@ -149,27 +150,44 @@ func (r *relationshipDB) getFollow(ctx context.Context, lookup string, dbQuery f
return follow, nil
}
// Set the follow source account
follow.Account, err = r.state.DB.GetAccountByID(
gtscontext.SetBarebones(ctx),
follow.AccountID,
)
if err != nil {
return nil, fmt.Errorf("error getting follow source account: %w", err)
}
// Set the follow target account
follow.TargetAccount, err = r.state.DB.GetAccountByID(
gtscontext.SetBarebones(ctx),
follow.TargetAccountID,
)
if err != nil {
return nil, fmt.Errorf("error getting follow target account: %w", err)
if err := r.state.DB.PopulateFollow(ctx, follow); err != nil {
return nil, err
}
return follow, nil
}
func (r *relationshipDB) PopulateFollow(ctx context.Context, follow *gtsmodel.Follow) error {
var (
err error
errs = make(gtserror.MultiError, 0, 2)
)
if follow.Account == nil {
// Follow account is not set, fetch from the database.
follow.Account, err = r.state.DB.GetAccountByID(
gtscontext.SetBarebones(ctx),
follow.AccountID,
)
if err != nil {
errs.Append(fmt.Errorf("error populating follow account: %w", err))
}
}
if follow.TargetAccount == nil {
// Follow target account is not set, fetch from the database.
follow.TargetAccount, err = r.state.DB.GetAccountByID(
gtscontext.SetBarebones(ctx),
follow.TargetAccountID,
)
if err != nil {
errs.Append(fmt.Errorf("error populating follow target account: %w", err))
}
}
return errs.Combine()
}
func (r *relationshipDB) PutFollow(ctx context.Context, follow *gtsmodel.Follow) error {
return r.state.Caches.GTS.Follow().Store(follow, func() error {
_, err := r.conn.NewInsert().Model(follow).Exec(ctx)
@ -197,27 +215,40 @@ func (r *relationshipDB) UpdateFollow(ctx context.Context, follow *gtsmodel.Foll
})
}
func (r *relationshipDB) deleteFollow(ctx context.Context, id string) error {
// Delete the follow itself using the given ID.
if _, err := r.conn.NewDelete().
Table("follows").
Where("? = ?", bun.Ident("id"), id).
Exec(ctx); err != nil {
return r.conn.ProcessError(err)
}
// Delete every list entry that used this followID.
if err := r.state.DB.DeleteListEntriesForFollowID(ctx, id); err != nil {
return fmt.Errorf("deleteFollow: error deleting list entries: %w", err)
}
return nil
}
func (r *relationshipDB) DeleteFollowByID(ctx context.Context, id string) error {
defer r.state.Caches.GTS.Follow().Invalidate("ID", id)
// Load follow into cache before attempting a delete,
// as we need it cached in order to trigger the invalidate
// callback. This in turn invalidates others.
_, err := r.GetFollowByID(gtscontext.SetBarebones(ctx), id)
follow, err := r.GetFollowByID(gtscontext.SetBarebones(ctx), id)
if err != nil {
if errors.Is(err, db.ErrNoEntries) {
// not an issue.
err = nil
// Already gone.
return nil
}
return err
}
// Finally delete follow from DB.
_, err = r.conn.NewDelete().
Table("follows").
Where("? = ?", bun.Ident("id"), id).
Exec(ctx)
return r.conn.ProcessError(err)
return r.deleteFollow(ctx, follow.ID)
}
func (r *relationshipDB) DeleteFollowByURI(ctx context.Context, uri string) error {
@ -226,21 +257,17 @@ func (r *relationshipDB) DeleteFollowByURI(ctx context.Context, uri string) erro
// Load follow into cache before attempting a delete,
// as we need it cached in order to trigger the invalidate
// callback. This in turn invalidates others.
_, err := r.GetFollowByURI(gtscontext.SetBarebones(ctx), uri)
follow, err := r.GetFollowByURI(gtscontext.SetBarebones(ctx), uri)
if err != nil {
if errors.Is(err, db.ErrNoEntries) {
// not an issue.
err = nil
// Already gone.
return nil
}
return err
}
// Finally delete follow from DB.
_, err = r.conn.NewDelete().
Table("follows").
Where("? = ?", bun.Ident("uri"), uri).
Exec(ctx)
return r.conn.ProcessError(err)
return r.deleteFollow(ctx, follow.ID)
}
func (r *relationshipDB) DeleteAccountFollows(ctx context.Context, accountID string) error {
@ -272,16 +299,16 @@ func (r *relationshipDB) DeleteAccountFollows(ctx context.Context, accountID str
// but it is the only way we can ensure we invalidate all
// related caches correctly (e.g. visibility).
for _, id := range followIDs {
_, err := r.GetFollowByID(ctx, id)
follow, err := r.GetFollowByID(ctx, id)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
return err
}
// Delete each follow from DB.
if err := r.deleteFollow(ctx, follow.ID); err != nil && !errors.Is(err, db.ErrNoEntries) {
return err
}
}
// Finally delete all from DB.
_, err := r.conn.NewDelete().
Table("follows").
Where("? IN (?)", bun.Ident("id"), bun.In(followIDs)).
Exec(ctx)
return r.conn.ProcessError(err)
return nil
}

View file

@ -807,16 +807,27 @@ func (suite *RelationshipTestSuite) TestUnfollowExisting() {
follow, err := suite.db.GetFollow(context.Background(), originAccount.ID, targetAccount.ID)
suite.NoError(err)
suite.NotNil(follow)
followID := follow.ID
err = suite.db.DeleteFollowByID(context.Background(), follow.ID)
// We should have list entries for this follow.
listEntries, err := suite.db.GetListEntriesForFollowID(context.Background(), followID)
suite.NoError(err)
suite.NotEmpty(listEntries)
err = suite.db.DeleteFollowByID(context.Background(), followID)
suite.NoError(err)
follow, err = suite.db.GetFollow(context.Background(), originAccount.ID, targetAccount.ID)
suite.EqualError(err, db.ErrNoEntries.Error())
suite.Nil(follow)
// ListEntries pertaining to this follow should be deleted too.
listEntries, err = suite.db.GetListEntriesForFollowID(context.Background(), followID)
suite.NoError(err)
suite.Empty(listEntries)
}
func (suite *RelationshipTestSuite) TestUnfollowNotExisting() {
func (suite *RelationshipTestSuite) TestGetFollowNotExisting() {
originAccount := suite.testAccounts["local_account_1"]
targetAccountID := "01GTVD9N484CZ6AM90PGGNY7GQ"

View file

@ -19,9 +19,11 @@ package bundb
import (
"context"
"fmt"
"time"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/id"
"github.com/superseriousbusiness/gotosocial/internal/log"
@ -281,3 +283,130 @@ func (t *timelineDB) GetFavedTimeline(ctx context.Context, accountID string, max
prevMinID := faves[0].ID
return statuses, nextMaxID, prevMinID, nil
}
func (t *timelineDB) GetListTimeline(
ctx context.Context,
listID string,
maxID string,
sinceID string,
minID string,
limit int,
) ([]*gtsmodel.Status, error) {
// Ensure reasonable
if limit < 0 {
limit = 0
}
// Make educated guess for slice size
var (
statusIDs = make([]string, 0, limit)
frontToBack = true
)
// Fetch all listEntries entries from the database.
listEntries, err := t.state.DB.GetListEntries(
// Don't need actual follows
// for this, just the IDs.
gtscontext.SetBarebones(ctx),
listID,
"", "", "", 0,
)
if err != nil {
return nil, fmt.Errorf("error getting entries for list %s: %w", listID, err)
}
// Extract just the IDs of each follow.
followIDs := make([]string, 0, len(listEntries))
for _, listEntry := range listEntries {
followIDs = append(followIDs, listEntry.FollowID)
}
// Select target account IDs from follows.
subQ := t.conn.
NewSelect().
TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")).
Column("follow.target_account_id").
Where("? IN (?)", bun.Ident("follow.id"), bun.In(followIDs))
// Select only status IDs created
// by one of the followed accounts.
q := t.conn.
NewSelect().
TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")).
// Select only IDs from table
Column("status.id").
Where("? IN (?)", bun.Ident("status.account_id"), subQ)
if maxID == "" || maxID >= id.Highest {
const future = 24 * time.Hour
var err error
// don't return statuses more than 24hr in the future
maxID, err = id.NewULIDFromTime(time.Now().Add(future))
if err != nil {
return nil, err
}
}
// return only statuses LOWER (ie., older) than maxID
q = q.Where("? < ?", bun.Ident("status.id"), maxID)
if sinceID != "" {
// return only statuses HIGHER (ie., newer) than sinceID
q = q.Where("? > ?", bun.Ident("status.id"), sinceID)
}
if minID != "" {
// return only statuses HIGHER (ie., newer) than minID
q = q.Where("? > ?", bun.Ident("status.id"), minID)
// page up
frontToBack = false
}
if limit > 0 {
// limit amount of statuses returned
q = q.Limit(limit)
}
if frontToBack {
// Page down.
q = q.Order("status.id DESC")
} else {
// Page up.
q = q.Order("status.id ASC")
}
if err := q.Scan(ctx, &statusIDs); err != nil {
return nil, t.conn.ProcessError(err)
}
if len(statusIDs) == 0 {
return nil, nil
}
// If we're paging up, we still want statuses
// to be sorted by ID desc, so reverse ids slice.
// https://zchee.github.io/golang-wiki/SliceTricks/#reversing
if !frontToBack {
for l, r := 0, len(statusIDs)-1; l < r; l, r = l+1, r-1 {
statusIDs[l], statusIDs[r] = statusIDs[r], statusIDs[l]
}
}
statuses := make([]*gtsmodel.Status, 0, len(statusIDs))
for _, id := range statusIDs {
// Fetch status from db for ID
status, err := t.state.DB.GetStatusByID(ctx, id)
if err != nil {
log.Errorf(ctx, "error fetching status %q: %v", id, err)
continue
}
// Append status to slice
statuses = append(statuses, status)
}
return statuses, nil
}

View file

@ -33,99 +33,6 @@ type TimelineTestSuite struct {
BunDBStandardTestSuite
}
func (suite *TimelineTestSuite) TestGetPublicTimeline() {
var count int
for _, status := range suite.testStatuses {
if status.Visibility == gtsmodel.VisibilityPublic &&
status.BoostOfID == "" {
count++
}
}
ctx := context.Background()
s, err := suite.db.GetPublicTimeline(ctx, "", "", "", 20, false)
suite.NoError(err)
suite.Len(s, count)
}
func (suite *TimelineTestSuite) TestGetPublicTimelineWithFutureStatus() {
var count int
for _, status := range suite.testStatuses {
if status.Visibility == gtsmodel.VisibilityPublic &&
status.BoostOfID == "" {
count++
}
}
ctx := context.Background()
futureStatus := getFutureStatus()
err := suite.db.PutStatus(ctx, futureStatus)
suite.NoError(err)
s, err := suite.db.GetPublicTimeline(ctx, "", "", "", 20, false)
suite.NoError(err)
suite.NotContains(s, futureStatus)
suite.Len(s, count)
}
func (suite *TimelineTestSuite) TestGetHomeTimeline() {
ctx := context.Background()
viewingAccount := suite.testAccounts["local_account_1"]
s, err := suite.db.GetHomeTimeline(ctx, viewingAccount.ID, "", "", "", 20, false)
suite.NoError(err)
suite.Len(s, 16)
}
func (suite *TimelineTestSuite) TestGetHomeTimelineWithFutureStatus() {
ctx := context.Background()
viewingAccount := suite.testAccounts["local_account_1"]
futureStatus := getFutureStatus()
err := suite.db.PutStatus(ctx, futureStatus)
suite.NoError(err)
s, err := suite.db.GetHomeTimeline(context.Background(), viewingAccount.ID, "", "", "", 20, false)
suite.NoError(err)
suite.NotContains(s, futureStatus)
suite.Len(s, 16)
}
func (suite *TimelineTestSuite) TestGetHomeTimelineBackToFront() {
ctx := context.Background()
viewingAccount := suite.testAccounts["local_account_1"]
s, err := suite.db.GetHomeTimeline(ctx, viewingAccount.ID, "", "", id.Lowest, 5, false)
suite.NoError(err)
suite.Len(s, 5)
suite.Equal("01F8MHAYFKS4KMXF8K5Y1C0KRN", s[0].ID)
suite.Equal("01F8MH75CBF9JFX4ZAD54N0W0R", s[len(s)-1].ID)
}
func (suite *TimelineTestSuite) TestGetHomeTimelineFromHighest() {
ctx := context.Background()
viewingAccount := suite.testAccounts["local_account_1"]
s, err := suite.db.GetHomeTimeline(ctx, viewingAccount.ID, id.Highest, "", "", 5, false)
suite.NoError(err)
suite.Len(s, 5)
suite.Equal("01G36SF3V6Y6V5BF9P4R7PQG7G", s[0].ID)
suite.Equal("01FCTA44PW9H1TB328S9AQXKDS", s[len(s)-1].ID)
}
func getFutureStatus() *gtsmodel.Status {
theDistantFuture := time.Now().Add(876600 * time.Hour)
id, err := id.NewULIDFromTime(theDistantFuture)
@ -163,6 +70,208 @@ func getFutureStatus() *gtsmodel.Status {
}
}
func (suite *TimelineTestSuite) publicCount() int {
var publicCount int
for _, status := range suite.testStatuses {
if status.Visibility == gtsmodel.VisibilityPublic &&
status.BoostOfID == "" {
publicCount++
}
}
return publicCount
}
func (suite *TimelineTestSuite) checkStatuses(statuses []*gtsmodel.Status, maxID string, minID string, expectedLength int) {
if l := len(statuses); l != expectedLength {
suite.FailNow("", "expected %d statuses in slice, got %d", expectedLength, l)
} else if l == 0 {
// Can't test empty slice.
return
}
// Check ordering + bounds of statuses.
highest := statuses[0].ID
for _, status := range statuses {
id := status.ID
if id >= maxID {
suite.FailNow("", "%s greater than maxID %s", id, maxID)
}
if id <= minID {
suite.FailNow("", "%s smaller than minID %s", id, minID)
}
if id > highest {
suite.FailNow("", "statuses in slice were not ordered highest -> lowest ID")
}
highest = id
}
}
func (suite *TimelineTestSuite) TestGetPublicTimeline() {
ctx := context.Background()
s, err := suite.db.GetPublicTimeline(ctx, "", "", "", 20, false)
if err != nil {
suite.FailNow(err.Error())
}
suite.checkStatuses(s, id.Highest, id.Lowest, suite.publicCount())
}
func (suite *TimelineTestSuite) TestGetPublicTimelineWithFutureStatus() {
ctx := context.Background()
// Insert a status set far in the
// future, it shouldn't be retrieved.
futureStatus := getFutureStatus()
if err := suite.db.PutStatus(ctx, futureStatus); err != nil {
suite.FailNow(err.Error())
}
s, err := suite.db.GetPublicTimeline(ctx, "", "", "", 20, false)
if err != nil {
suite.FailNow(err.Error())
}
suite.NotContains(s, futureStatus)
suite.checkStatuses(s, id.Highest, id.Lowest, suite.publicCount())
}
func (suite *TimelineTestSuite) TestGetHomeTimeline() {
var (
ctx = context.Background()
viewingAccount = suite.testAccounts["local_account_1"]
)
s, err := suite.db.GetHomeTimeline(ctx, viewingAccount.ID, "", "", "", 20, false)
if err != nil {
suite.FailNow(err.Error())
}
suite.checkStatuses(s, id.Highest, id.Lowest, 16)
}
func (suite *TimelineTestSuite) TestGetHomeTimelineWithFutureStatus() {
var (
ctx = context.Background()
viewingAccount = suite.testAccounts["local_account_1"]
)
// Insert a status set far in the
// future, it shouldn't be retrieved.
futureStatus := getFutureStatus()
if err := suite.db.PutStatus(ctx, futureStatus); err != nil {
suite.FailNow(err.Error())
}
s, err := suite.db.GetHomeTimeline(ctx, viewingAccount.ID, "", "", "", 20, false)
if err != nil {
suite.FailNow(err.Error())
}
suite.NotContains(s, futureStatus)
suite.checkStatuses(s, id.Highest, id.Lowest, 16)
}
func (suite *TimelineTestSuite) TestGetHomeTimelineBackToFront() {
var (
ctx = context.Background()
viewingAccount = suite.testAccounts["local_account_1"]
)
s, err := suite.db.GetHomeTimeline(ctx, viewingAccount.ID, "", "", id.Lowest, 5, false)
if err != nil {
suite.FailNow(err.Error())
}
suite.checkStatuses(s, id.Highest, id.Lowest, 5)
suite.Equal("01F8MHAYFKS4KMXF8K5Y1C0KRN", s[0].ID)
suite.Equal("01F8MH75CBF9JFX4ZAD54N0W0R", s[len(s)-1].ID)
}
func (suite *TimelineTestSuite) TestGetHomeTimelineFromHighest() {
var (
ctx = context.Background()
viewingAccount = suite.testAccounts["local_account_1"]
)
s, err := suite.db.GetHomeTimeline(ctx, viewingAccount.ID, id.Highest, "", "", 5, false)
if err != nil {
suite.FailNow(err.Error())
}
suite.checkStatuses(s, id.Highest, id.Lowest, 5)
suite.Equal("01G36SF3V6Y6V5BF9P4R7PQG7G", s[0].ID)
suite.Equal("01FCTA44PW9H1TB328S9AQXKDS", s[len(s)-1].ID)
}
func (suite *TimelineTestSuite) TestGetListTimelineNoParams() {
var (
ctx = context.Background()
list = suite.testLists["local_account_1_list_1"]
)
s, err := suite.db.GetListTimeline(ctx, list.ID, "", "", "", 20)
if err != nil {
suite.FailNow(err.Error())
}
suite.checkStatuses(s, id.Highest, id.Lowest, 11)
}
func (suite *TimelineTestSuite) TestGetListTimelineMaxID() {
var (
ctx = context.Background()
list = suite.testLists["local_account_1_list_1"]
)
s, err := suite.db.GetListTimeline(ctx, list.ID, id.Highest, "", "", 5)
if err != nil {
suite.FailNow(err.Error())
}
suite.checkStatuses(s, id.Highest, id.Lowest, 5)
suite.Equal("01G36SF3V6Y6V5BF9P4R7PQG7G", s[0].ID)
suite.Equal("01FCQSQ667XHJ9AV9T27SJJSX5", s[len(s)-1].ID)
}
func (suite *TimelineTestSuite) TestGetListTimelineMinID() {
var (
ctx = context.Background()
list = suite.testLists["local_account_1_list_1"]
)
s, err := suite.db.GetListTimeline(ctx, list.ID, "", "", id.Lowest, 5)
if err != nil {
suite.FailNow(err.Error())
}
suite.checkStatuses(s, id.Highest, id.Lowest, 5)
suite.Equal("01F8MHC8VWDRBQR0N1BATDDEM5", s[0].ID)
suite.Equal("01F8MH75CBF9JFX4ZAD54N0W0R", s[len(s)-1].ID)
}
func (suite *TimelineTestSuite) TestGetListTimelineMinIDPagingUp() {
var (
ctx = context.Background()
list = suite.testLists["local_account_1_list_1"]
)
s, err := suite.db.GetListTimeline(ctx, list.ID, "", "", "01F8MHC8VWDRBQR0N1BATDDEM5", 5)
if err != nil {
suite.FailNow(err.Error())
}
suite.checkStatuses(s, id.Highest, "01F8MHC8VWDRBQR0N1BATDDEM5", 5)
suite.Equal("01G20ZM733MGN8J344T4ZDDFY1", s[0].ID)
suite.Equal("01F8MHCP5P2NWYQ416SBA0XSEV", s[len(s)-1].ID)
}
func TestTimelineTestSuite(t *testing.T) {
suite.Run(t, new(TimelineTestSuite))
}

View file

@ -36,6 +36,7 @@ type DB interface {
Domain
Emoji
Instance
List
Media
Mention
Notification

67
internal/db/list.go Normal file
View file

@ -0,0 +1,67 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package db
import (
"context"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
type List interface {
// GetListByID gets one list with the given id.
GetListByID(ctx context.Context, id string) (*gtsmodel.List, error)
// GetListsForAccountID gets all lists owned by the given accountID.
GetListsForAccountID(ctx context.Context, accountID string) ([]*gtsmodel.List, error)
// PopulateList ensures that the list's struct fields are populated.
PopulateList(ctx context.Context, list *gtsmodel.List) error
// PutList puts a new list in the database.
PutList(ctx context.Context, list *gtsmodel.List) error
// UpdateList updates the given list.
// Columns is optional, if not specified all will be updated.
UpdateList(ctx context.Context, list *gtsmodel.List, columns ...string) error
// DeleteListByID deletes one list with the given ID.
DeleteListByID(ctx context.Context, id string) error
// GetListEntryByID gets one list entry with the given ID.
GetListEntryByID(ctx context.Context, id string) (*gtsmodel.ListEntry, error)
// GetListEntries gets list entries from the given listID, using the given parameters.
GetListEntries(ctx context.Context, listID string, maxID string, sinceID string, minID string, limit int) ([]*gtsmodel.ListEntry, error)
// GetListEntriesForFollowID returns all listEntries that pertain to the given followID.
GetListEntriesForFollowID(ctx context.Context, followID string) ([]*gtsmodel.ListEntry, error)
// PopulateListEntry ensures that the listEntry's struct fields are populated.
PopulateListEntry(ctx context.Context, listEntry *gtsmodel.ListEntry) error
// PutListEntries inserts a slice of listEntries into the database.
// It uses a transaction to ensure no partial updates.
PutListEntries(ctx context.Context, listEntries []*gtsmodel.ListEntry) error
// DeleteListEntry deletes one list entry with the given id.
DeleteListEntry(ctx context.Context, id string) error
// DeleteListEntryForFollowID deletes all list entries with the given followID.
DeleteListEntriesForFollowID(ctx context.Context, followID string) error
}

View file

@ -64,6 +64,9 @@ type Relationship interface {
// GetFollow retrieves a follow if it exists between source and target accounts.
GetFollow(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.Follow, error)
// PopulateFollow populates the struct pointers on the given follow.
PopulateFollow(ctx context.Context, follow *gtsmodel.Follow) error
// GetFollowRequestByID fetches follow request with given ID from the database.
GetFollowRequestByID(ctx context.Context, id string) (*gtsmodel.FollowRequest, error)

View file

@ -44,4 +44,8 @@ type Timeline interface {
//
// Also note the extra return values, which correspond to the nextMaxID and prevMinID for building Link headers.
GetFavedTimeline(ctx context.Context, accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, Error)
// GetListTimeline returns a slice of statuses from followed accounts collected within the list with the given listID.
// Statuses should be returned in descending order of when they were created (newest first).
GetListTimeline(ctx context.Context, listID string, maxID string, sinceID string, minID string, limit int) ([]*gtsmodel.Status, error)
}

View file

@ -25,6 +25,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -61,6 +62,13 @@ func (suite *DereferencerStandardTestSuite) SetupTest() {
testrig.StartWorkers(&suite.state)
suite.db = testrig.NewTestDB(&suite.state)
testrig.StartTimelines(
&suite.state,
visibility.NewFilter(&suite.state),
testrig.NewTestTypeConverter(suite.db),
)
suite.storage = testrig.NewInMemoryStorage()
suite.state.DB = suite.db
suite.state.Storage = suite.storage

View file

@ -28,6 +28,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -76,8 +77,16 @@ func (suite *FederatingDBTestSuite) SetupTest() {
}
suite.db = testrig.NewTestDB(&suite.state)
suite.testActivities = testrig.NewTestActivities(suite.testAccounts)
suite.tc = testrig.NewTestTypeConverter(suite.db)
testrig.StartTimelines(
&suite.state,
visibility.NewFilter(&suite.state),
suite.tc,
)
suite.federatingDB = testrig.NewTestFederatingDB(&suite.state)
testrig.StandardDBSetup(suite.db, suite.testAccounts)

View file

@ -25,6 +25,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -58,6 +59,13 @@ func (suite *FederatorStandardTestSuite) SetupTest() {
suite.db = testrig.NewTestDB(&suite.state)
suite.tc = testrig.NewTestTypeConverter(suite.db)
suite.state.DB = suite.db
testrig.StartTimelines(
&suite.state,
visibility.NewFilter(&suite.state),
suite.tc,
)
suite.testActivities = testrig.NewTestActivities(suite.testAccounts)
testrig.StandardDBSetup(suite.db, suite.testAccounts)
}

51
internal/gtsmodel/list.go Normal file
View file

@ -0,0 +1,51 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package gtsmodel
import "time"
// List refers to a list of follows for which the owning account wants to view a timeline of posts.
type List struct {
ID string `validate:"required,ulid" bun:"type:CHAR(26),pk,nullzero,notnull,unique"` // id of this item in the database
CreatedAt time.Time `validate:"-" bun:"type:timestamptz,nullzero,notnull,default:current_timestamp"` // when was item created
UpdatedAt time.Time `validate:"-" bun:"type:timestamptz,nullzero,notnull,default:current_timestamp"` // when was item last updated
Title string `validate:"required" bun:",nullzero,notnull,unique:listaccounttitle"` // Title of this list.
AccountID string `validate:"required,ulid" bun:"type:CHAR(26),notnull,nullzero,unique:listaccounttitle"` // Account that created/owns the list
Account *Account `validate:"-" bun:"-"` // Account corresponding to accountID
ListEntries []*ListEntry `validate:"-" bun:"-"` // Entries contained by this list.
RepliesPolicy RepliesPolicy `validate:"-" bun:",nullzero,notnull,default:'followed'"` // RepliesPolicy for this list.
}
// ListEntry refers to a single follow entry in a list.
type ListEntry struct {
ID string `validate:"required,ulid" bun:"type:CHAR(26),pk,nullzero,notnull,unique"` // id of this item in the database
CreatedAt time.Time `validate:"-" bun:"type:timestamptz,nullzero,notnull,default:current_timestamp"` // when was item created
UpdatedAt time.Time `validate:"-" bun:"type:timestamptz,nullzero,notnull,default:current_timestamp"` // when was item last updated
ListID string `validate:"required,ulid" bun:"type:CHAR(26),notnull,nullzero,unique:listentrylistfollow"` // ID of the list that this entry belongs to.
FollowID string `validate:"required,ulid" bun:"type:CHAR(26),notnull,nullzero,unique:listentrylistfollow"` // Follow that the account owning this entry wants to see posts of in the timeline.
Follow *Follow `validate:"-" bun:"-"` // Follow corresponding to followID.
}
// RepliesPolicy denotes which replies should be shown in the list.
type RepliesPolicy string
const (
RepliesPolicyFollowed RepliesPolicy = "followed" // Show replies to any followed user.
RepliesPolicyList RepliesPolicy = "list" // Show replies to members of the list only.
RepliesPolicyNone RepliesPolicy = "none" // Don't show replies.
)

View file

@ -33,6 +33,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/media"
"github.com/superseriousbusiness/gotosocial/internal/state"
gtsstorage "github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/testrig"
)
type ManagerTestSuite struct {
@ -395,9 +396,6 @@ func (suite *ManagerTestSuite) TestSlothVineProcessBlocking() {
// fetch the attachment id from the processing media
attachmentID := processingMedia.AttachmentID()
// Give time for processing
time.Sleep(time.Second * 3)
// do a blocking call to fetch the attachment
attachment, err := processingMedia.LoadAttachment(ctx)
suite.NoError(err)
@ -1027,13 +1025,14 @@ func (suite *ManagerTestSuite) TestSimpleJpegProcessAsync() {
// fetch the attachment id from the processing media
attachmentID := processingMedia.AttachmentID()
// Give time for processing to happen.
time.Sleep(time.Second * 3)
// fetch the attachment from the database
attachment, err := suite.db.GetAttachmentByID(ctx, attachmentID)
suite.NoError(err)
suite.NotNil(attachment)
// wait for processing to complete
var attachment *gtsmodel.MediaAttachment
if !testrig.WaitFor(func() bool {
attachment, err = suite.db.GetAttachmentByID(ctx, attachmentID)
return err == nil && attachment != nil
}) {
suite.FailNow("timed out waiting for attachment to process")
}
// make sure it's got the stuff set on it that we expect
// the attachment ID and accountID we expect

View file

@ -25,6 +25,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/transport"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -41,22 +42,27 @@ type MediaStandardTestSuite struct {
testEmojis map[string]*gtsmodel.Emoji
}
func (suite *MediaStandardTestSuite) SetupSuite() {
func (suite *MediaStandardTestSuite) SetupTest() {
testrig.InitTestConfig()
testrig.InitTestLog()
suite.state.Caches.Init()
testrig.StartWorkers(&suite.state)
suite.db = testrig.NewTestDB(&suite.state)
suite.storage = testrig.NewInMemoryStorage()
suite.state.DB = suite.db
suite.state.Storage = suite.storage
}
func (suite *MediaStandardTestSuite) SetupTest() {
suite.state.Caches.Init()
testrig.StartWorkers(&suite.state)
testrig.StandardStorageSetup(suite.storage, "../../testrig/media")
testrig.StandardDBSetup(suite.db, nil)
testrig.StartTimelines(
&suite.state,
visibility.NewFilter(&suite.state),
testrig.NewTestTypeConverter(suite.db),
)
suite.testAttachments = testrig.NewTestAttachments()
suite.testAccounts = testrig.NewTestAccounts()
suite.testEmojis = testrig.NewTestEmojis()

View file

@ -88,6 +88,13 @@ func (suite *AccountStandardTestSuite) SetupTest() {
suite.db = testrig.NewTestDB(&suite.state)
suite.state.DB = suite.db
suite.tc = testrig.NewTestTypeConverter(suite.db)
testrig.StartTimelines(
&suite.state,
visibility.NewFilter(&suite.state),
suite.tc,
)
suite.storage = testrig.NewInMemoryStorage()
suite.state.Storage = suite.storage
suite.mediaManager = testrig.NewTestMediaManager(&suite.state)

View file

@ -0,0 +1,107 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package account
import (
"context"
"errors"
"fmt"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
)
var noLists = make([]*apimodel.List, 0)
// ListsGet returns all lists owned by requestingAccount, which contain a follow for targetAccountID.
func (p *Processor) ListsGet(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) ([]*apimodel.List, gtserror.WithCode) {
targetAccount, err := p.state.DB.GetAccountByID(ctx, targetAccountID)
if err != nil {
if errors.Is(err, db.ErrNoEntries) {
return nil, gtserror.NewErrorNotFound(errors.New("account not found"))
}
return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error: %w", err))
}
visible, err := p.filter.AccountVisible(ctx, requestingAccount, targetAccount)
if err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error: %w", err))
}
if !visible {
return nil, gtserror.NewErrorNotFound(errors.New("account not found"))
}
// Requester has to follow targetAccount
// for them to be in any of their lists.
follow, err := p.state.DB.GetFollow(
// Don't populate follow.
gtscontext.SetBarebones(ctx),
requestingAccount.ID,
targetAccountID,
)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error: %w", err))
}
if follow == nil {
return noLists, nil // by definition we know they're in no lists
}
listEntries, err := p.state.DB.GetListEntriesForFollowID(
// Don't populate entries.
gtscontext.SetBarebones(ctx),
follow.ID,
)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error: %w", err))
}
count := len(listEntries)
if count == 0 {
return noLists, nil
}
apiLists := make([]*apimodel.List, 0, count)
for _, listEntry := range listEntries {
list, err := p.state.DB.GetListByID(
// Don't populate list.
gtscontext.SetBarebones(ctx),
listEntry.ListID,
)
if err != nil {
log.Debugf(ctx, "skipping list %s due to error %q", listEntry.ListID, err)
continue
}
apiList, err := p.tc.ListToAPIList(ctx, list)
if err != nil {
log.Debugf(ctx, "skipping list %s due to error %q", listEntry.ListID, err)
continue
}
apiLists = append(apiLists, apiList)
}
return apiLists, nil
}

View file

@ -217,10 +217,10 @@ func (p *Processor) processCreateBlockFromClientAPI(ctx context.Context, clientM
}
// remove any of the blocking account's statuses from the blocked account's timeline, and vice versa
if err := p.statusTimelines.WipeItemsFromAccountID(ctx, block.AccountID, block.TargetAccountID); err != nil {
if err := p.state.Timelines.Home.WipeItemsFromAccountID(ctx, block.AccountID, block.TargetAccountID); err != nil {
return err
}
if err := p.statusTimelines.WipeItemsFromAccountID(ctx, block.TargetAccountID, block.AccountID); err != nil {
if err := p.state.Timelines.Home.WipeItemsFromAccountID(ctx, block.TargetAccountID, block.AccountID); err != nil {
return err
}

View file

@ -20,6 +20,7 @@ package processing_test
import (
"context"
"encoding/json"
"errors"
"testing"
"github.com/stretchr/testify/suite"
@ -36,24 +37,21 @@ type FromClientAPITestSuite struct {
ProcessingStandardTestSuite
}
// This test ensures that when admin_account posts a new
// status, it ends up in the correct streaming timelines
// of local_account_1, which follows it.
func (suite *FromClientAPITestSuite) TestProcessStreamNewStatus() {
ctx := context.Background()
var (
ctx = context.Background()
postingAccount = suite.testAccounts["admin_account"]
receivingAccount = suite.testAccounts["local_account_1"]
testList = suite.testLists["local_account_1_list_1"]
streams = suite.openStreams(ctx, receivingAccount, []string{testList.ID})
homeStream = streams[stream.TimelineHome]
listStream = streams[stream.TimelineList+":"+testList.ID]
)
// let's say that the admin account posts a new status: it should end up in the
// timeline of any account that follows it and has a stream open
postingAccount := suite.testAccounts["admin_account"]
receivingAccount := suite.testAccounts["local_account_1"]
// open a home timeline stream for zork
wssStream, errWithCode := suite.processor.Stream().Open(ctx, receivingAccount, stream.TimelineHome)
suite.NoError(errWithCode)
// open another stream for zork, but for a different timeline;
// this shouldn't get stuff streamed into it, since it's for the public timeline
irrelevantStream, errWithCode := suite.processor.Stream().Open(ctx, receivingAccount, stream.TimelinePublic)
suite.NoError(errWithCode)
// make a new status from admin account
// Make a new status from admin account.
newStatus := &gtsmodel.Status{
ID: "01FN4B2F88TF9676DYNXWE1WSS",
URI: "http://localhost:8080/users/admin/statuses/01FN4B2F88TF9676DYNXWE1WSS",
@ -82,87 +80,110 @@ func (suite *FromClientAPITestSuite) TestProcessStreamNewStatus() {
ActivityStreamsType: ap.ObjectNote,
}
// put the status in the db first, to mimic what would have already happened earlier up the flow
err := suite.db.PutStatus(ctx, newStatus)
suite.NoError(err)
// Put the status in the db first, to mimic what
// would have already happened earlier up the flow.
if err := suite.db.PutStatus(ctx, newStatus); err != nil {
suite.FailNow(err.Error())
}
// process the new status
err = suite.processor.ProcessFromClientAPI(ctx, messages.FromClientAPI{
// Process the new status.
if err := suite.processor.ProcessFromClientAPI(ctx, messages.FromClientAPI{
APObjectType: ap.ObjectNote,
APActivityType: ap.ActivityCreate,
GTSModel: newStatus,
OriginAccount: postingAccount,
})
suite.NoError(err)
}); err != nil {
suite.FailNow(err.Error())
}
// zork's stream should have the newly created status in it now
msg := <-wssStream.Messages
suite.Equal(stream.EventTypeUpdate, msg.Event)
suite.NotEmpty(msg.Payload)
suite.EqualValues([]string{stream.TimelineHome}, msg.Stream)
statusStreamed := &apimodel.Status{}
err = json.Unmarshal([]byte(msg.Payload), statusStreamed)
suite.NoError(err)
suite.Equal("01FN4B2F88TF9676DYNXWE1WSS", statusStreamed.ID)
suite.Equal("this status should stream :)", statusStreamed.Content)
// Check message in home stream.
homeMsg := <-homeStream.Messages
suite.Equal(stream.EventTypeUpdate, homeMsg.Event)
suite.EqualValues([]string{stream.TimelineHome}, homeMsg.Stream)
suite.Empty(homeStream.Messages) // Stream should now be empty.
// and stream should now be empty
suite.Empty(wssStream.Messages)
// Check status from home stream.
homeStreamStatus := &apimodel.Status{}
if err := json.Unmarshal([]byte(homeMsg.Payload), homeStreamStatus); err != nil {
suite.FailNow(err.Error())
}
suite.Equal(newStatus.ID, homeStreamStatus.ID)
suite.Equal(newStatus.Content, homeStreamStatus.Content)
// the irrelevant messages stream should also be empty
suite.Empty(irrelevantStream.Messages)
// Check message in list stream.
listMsg := <-listStream.Messages
suite.Equal(stream.EventTypeUpdate, listMsg.Event)
suite.EqualValues([]string{stream.TimelineList + ":" + testList.ID}, listMsg.Stream)
suite.Empty(listStream.Messages) // Stream should now be empty.
// Check status from list stream.
listStreamStatus := &apimodel.Status{}
if err := json.Unmarshal([]byte(listMsg.Payload), listStreamStatus); err != nil {
suite.FailNow(err.Error())
}
suite.Equal(newStatus.ID, listStreamStatus.ID)
suite.Equal(newStatus.Content, listStreamStatus.Content)
}
func (suite *FromClientAPITestSuite) TestProcessStatusDelete() {
ctx := context.Background()
var (
ctx = context.Background()
deletingAccount = suite.testAccounts["local_account_1"]
receivingAccount = suite.testAccounts["local_account_2"]
deletedStatus = suite.testStatuses["local_account_1_status_1"]
boostOfDeletedStatus = suite.testStatuses["admin_account_status_4"]
streams = suite.openStreams(ctx, receivingAccount, nil)
homeStream = streams[stream.TimelineHome]
)
deletingAccount := suite.testAccounts["local_account_1"]
receivingAccount := suite.testAccounts["local_account_2"]
// Delete the status from the db first, to mimic what
// would have already happened earlier up the flow
if err := suite.db.DeleteStatusByID(ctx, deletedStatus.ID); err != nil {
suite.FailNow(err.Error())
}
deletedStatus := suite.testStatuses["local_account_1_status_1"]
boostOfDeletedStatus := suite.testStatuses["admin_account_status_4"]
// open a home timeline stream for turtle, who follows zork
wssStream, errWithCode := suite.processor.Stream().Open(ctx, receivingAccount, stream.TimelineHome)
suite.NoError(errWithCode)
// delete the status from the db first, to mimic what would have already happened earlier up the flow
err := suite.db.DeleteStatusByID(ctx, deletedStatus.ID)
suite.NoError(err)
// process the status delete
err = suite.processor.ProcessFromClientAPI(ctx, messages.FromClientAPI{
// Process the status delete.
if err := suite.processor.ProcessFromClientAPI(ctx, messages.FromClientAPI{
APObjectType: ap.ObjectNote,
APActivityType: ap.ActivityDelete,
GTSModel: deletedStatus,
OriginAccount: deletingAccount,
})
suite.NoError(err)
}); err != nil {
suite.FailNow(err.Error())
}
// turtle's stream should have the delete of admin's boost in it now
msg := <-wssStream.Messages
// Stream should have the delete of admin's boost in it now.
msg := <-homeStream.Messages
suite.Equal(stream.EventTypeDelete, msg.Event)
suite.Equal(boostOfDeletedStatus.ID, msg.Payload)
suite.EqualValues([]string{stream.TimelineHome}, msg.Stream)
// turtle's stream should also have the delete of the message itself in it
msg = <-wssStream.Messages
// Stream should also have the delete of the message itself in it.
msg = <-homeStream.Messages
suite.Equal(stream.EventTypeDelete, msg.Event)
suite.Equal(deletedStatus.ID, msg.Payload)
suite.EqualValues([]string{stream.TimelineHome}, msg.Stream)
// stream should now be empty
suite.Empty(wssStream.Messages)
// Stream should now be empty.
suite.Empty(homeStream.Messages)
// the boost should no longer be in the database
_, err = suite.db.GetStatusByID(ctx, boostOfDeletedStatus.ID)
suite.ErrorIs(err, db.ErrNoEntries)
// Boost should no longer be in the database.
if !testrig.WaitFor(func() bool {
_, err := suite.db.GetStatusByID(ctx, boostOfDeletedStatus.ID)
return errors.Is(err, db.ErrNoEntries)
}) {
suite.FailNow("timed out waiting for status delete")
}
}
func (suite *FromClientAPITestSuite) TestProcessNewStatusWithNotification() {
ctx := context.Background()
postingAccount := suite.testAccounts["admin_account"]
receivingAccount := suite.testAccounts["local_account_1"]
var (
ctx = context.Background()
postingAccount = suite.testAccounts["admin_account"]
receivingAccount = suite.testAccounts["local_account_1"]
streams = suite.openStreams(ctx, receivingAccount, nil)
notifStream = streams[stream.TimelineNotifications]
)
// Update the follow from receiving account -> posting account so
// that receiving account wants notifs when posting account posts.
@ -204,8 +225,9 @@ func (suite *FromClientAPITestSuite) TestProcessNewStatusWithNotification() {
// Put the status in the db first, to mimic what
// would have already happened earlier up the flow.
err := suite.db.PutStatus(ctx, newStatus)
suite.NoError(err)
if err := suite.db.PutStatus(ctx, newStatus); err != nil {
suite.FailNow(err.Error())
}
// Process the new status.
if err := suite.processor.ProcessFromClientAPI(ctx, messages.FromClientAPI{
@ -230,6 +252,19 @@ func (suite *FromClientAPITestSuite) TestProcessNewStatusWithNotification() {
}) {
suite.FailNow("timed out waiting for new status notification")
}
// Check message in notification stream.
notifMsg := <-notifStream.Messages
suite.Equal(stream.EventTypeNotification, notifMsg.Event)
suite.EqualValues([]string{stream.TimelineNotifications}, notifMsg.Stream)
suite.Empty(notifStream.Messages) // Stream should now be empty.
// Check notif.
notif := &apimodel.Notification{}
if err := json.Unmarshal([]byte(notifMsg.Payload), notif); err != nil {
suite.FailNow(err.Error())
}
suite.Equal(newStatus.ID, notif.Status.ID)
}
func TestFromClientAPITestSuite(t *testing.T) {

View file

@ -30,12 +30,14 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/id"
"github.com/superseriousbusiness/gotosocial/internal/stream"
"github.com/superseriousbusiness/gotosocial/internal/timeline"
)
// timelineAndNotifyStatus processes the given new status and inserts it into
// the HOME timelines of accounts that follow the status author. It will also
// handle notifications for any mentions attached to the account, and also
// notifications for any local accounts that want a notif when this account posts.
// the HOME and LIST timelines of accounts that follow the status author.
//
// It will also handle notifications for any mentions attached to the account, and
// also notifications for any local accounts that want to know when this account posts.
func (p *Processor) timelineAndNotifyStatus(ctx context.Context, status *gtsmodel.Status) error {
// Ensure status fully populated; including account, mentions, etc.
if err := p.state.DB.PopulateStatus(ctx, status); err != nil {
@ -89,10 +91,43 @@ func (p *Processor) timelineAndNotifyStatusForFollowers(ctx context.Context, sta
continue
}
// Add status to each list that this follow
// is included in, and stream it if applicable.
listEntries, err := p.state.DB.GetListEntriesForFollowID(
// We only need the list IDs.
gtscontext.SetBarebones(ctx),
follow.ID,
)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
errs.Append(fmt.Errorf("timelineAndNotifyStatusForFollowers: error list timelining status: %w", err))
continue
}
for _, listEntry := range listEntries {
if _, err := p.timelineStatus(
ctx,
p.state.Timelines.List.IngestOne,
listEntry.ListID, // list timelines are keyed by list ID
follow.Account,
status,
stream.TimelineList+":"+listEntry.ListID, // key streamType to this specific list
); err != nil {
errs.Append(fmt.Errorf("timelineAndNotifyStatusForFollowers: error list timelining status: %w", err))
continue
}
}
// Add status to home timeline for this
// follower, and stream it if applicable.
if timelined, err := p.timelineStatusForAccount(ctx, follow.Account, status); err != nil {
errs.Append(fmt.Errorf("timelineAndNotifyStatusForFollowers: error timelining status: %w", err))
if timelined, err := p.timelineStatus(
ctx,
p.state.Timelines.Home.IngestOne,
follow.AccountID, // home timelines are keyed by account ID
follow.Account,
status,
stream.TimelineHome,
); err != nil {
errs.Append(fmt.Errorf("timelineAndNotifyStatusForFollowers: error home timelining status: %w", err))
continue
} else if !timelined {
// Status wasn't added to home tomeline,
@ -133,13 +168,21 @@ func (p *Processor) timelineAndNotifyStatusForFollowers(ctx context.Context, sta
return errs.Combine()
}
// timelineStatusForAccount puts the given status in the HOME timeline
// of the account with given accountID, if it's HomeTimelineable.
// timelineStatus uses the provided ingest function to put the given
// status in a timeline with the given ID, if it's timelineable.
//
// If the status was inserted into the home timeline of the given account,
// true will be returned + it will also be streamed via websockets to the user.
func (p *Processor) timelineStatusForAccount(ctx context.Context, account *gtsmodel.Account, status *gtsmodel.Status) (bool, error) {
// If the status was inserted into the timeline, true will be returned
// + it will also be streamed to the user using the given streamType.
func (p *Processor) timelineStatus(
ctx context.Context,
ingest func(context.Context, string, timeline.Timelineable) (bool, error),
timelineID string,
account *gtsmodel.Account,
status *gtsmodel.Status,
streamType string,
) (bool, error) {
// Make sure the status is timelineable.
// This works for both home and list timelines.
if timelineable, err := p.filter.StatusHomeTimelineable(ctx, account, status); err != nil {
err = fmt.Errorf("timelineStatusForAccount: error getting timelineability for status for timeline with id %s: %w", account.ID, err)
return false, err
@ -148,8 +191,8 @@ func (p *Processor) timelineStatusForAccount(ctx context.Context, account *gtsmo
return false, nil
}
// Insert status in the home timeline of account.
if inserted, err := p.statusTimelines.IngestOne(ctx, account.ID, status); err != nil {
// Ingest status into given timeline using provided function.
if inserted, err := ingest(ctx, timelineID, status); err != nil {
err = fmt.Errorf("timelineStatusForAccount: error ingesting status %s: %w", status.ID, err)
return false, err
} else if !inserted {
@ -164,7 +207,7 @@ func (p *Processor) timelineStatusForAccount(ctx context.Context, account *gtsmo
return true, err
}
if err := p.stream.Update(apiStatus, account, stream.TimelineHome); err != nil {
if err := p.stream.Update(apiStatus, account, []string{streamType}); err != nil {
err = fmt.Errorf("timelineStatusForAccount: error streaming update for status %s: %w", status.ID, err)
return true, err
}
@ -401,7 +444,7 @@ func (p *Processor) wipeStatus(ctx context.Context, statusToDelete *gtsmodel.Sta
// deleteStatusFromTimelines completely removes the given status from all timelines.
// It will also stream deletion of the status to all open streams.
func (p *Processor) deleteStatusFromTimelines(ctx context.Context, status *gtsmodel.Status) error {
if err := p.statusTimelines.WipeItemFromAllTimelines(ctx, status.ID); err != nil {
if err := p.state.Timelines.Home.WipeItemFromAllTimelines(ctx, status.ID); err != nil {
return err
}

View file

@ -342,10 +342,10 @@ func (p *Processor) processCreateBlockFromFederator(ctx context.Context, federat
}
// remove any of the blocking account's statuses from the blocked account's timeline, and vice versa
if err := p.statusTimelines.WipeItemsFromAccountID(ctx, block.AccountID, block.TargetAccountID); err != nil {
if err := p.state.Timelines.Home.WipeItemsFromAccountID(ctx, block.AccountID, block.TargetAccountID); err != nil {
return err
}
if err := p.statusTimelines.WipeItemsFromAccountID(ctx, block.TargetAccountID, block.AccountID); err != nil {
if err := p.state.Timelines.Home.WipeItemsFromAccountID(ctx, block.TargetAccountID, block.AccountID); err != nil {
return err
}
// TODO: same with notifications

View file

@ -0,0 +1,50 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package list
import (
"context"
"errors"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/id"
)
// Create creates one a new list for the given account, using the provided parameters.
// These params should have already been validated by the time they reach this function.
func (p *Processor) Create(ctx context.Context, account *gtsmodel.Account, title string, repliesPolicy gtsmodel.RepliesPolicy) (*apimodel.List, gtserror.WithCode) {
list := &gtsmodel.List{
ID: id.NewULID(),
Title: title,
AccountID: account.ID,
RepliesPolicy: repliesPolicy,
}
if err := p.state.DB.PutList(ctx, list); err != nil {
if errors.Is(err, db.ErrAlreadyExists) {
err = errors.New("you already have a list with this title")
return nil, gtserror.NewErrorConflict(err, err.Error())
}
return nil, gtserror.NewErrorInternalError(err)
}
return p.apiList(ctx, list)
}

View file

@ -0,0 +1,46 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package list
import (
"context"
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
// Delete deletes one list for the given account.
func (p *Processor) Delete(ctx context.Context, account *gtsmodel.Account, id string) gtserror.WithCode {
list, errWithCode := p.getList(
// Use barebones ctx; no embedded
// structs necessary for this call.
gtscontext.SetBarebones(ctx),
account.ID,
id,
)
if errWithCode != nil {
return errWithCode
}
if err := p.state.DB.DeleteListByID(ctx, list.ID); err != nil {
return gtserror.NewErrorInternalError(err)
}
return nil
}

View file

@ -0,0 +1,155 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package list
import (
"context"
"errors"
"fmt"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/util"
)
// Get returns the api model of one list with the given ID.
func (p *Processor) Get(ctx context.Context, account *gtsmodel.Account, id string) (*apimodel.List, gtserror.WithCode) {
list, errWithCode := p.getList(
// Use barebones ctx; no embedded
// structs necessary for this call.
gtscontext.SetBarebones(ctx),
account.ID,
id,
)
if errWithCode != nil {
return nil, errWithCode
}
return p.apiList(ctx, list)
}
// GetMultiple returns multiple lists created by the given account, sorted by list ID DESC (newest first).
func (p *Processor) GetAll(ctx context.Context, account *gtsmodel.Account) ([]*apimodel.List, gtserror.WithCode) {
lists, err := p.state.DB.GetListsForAccountID(
// Use barebones ctx; no embedded
// structs necessary for simple GET.
gtscontext.SetBarebones(ctx),
account.ID,
)
if err != nil {
if errors.Is(err, db.ErrNoEntries) {
return nil, nil
}
return nil, gtserror.NewErrorInternalError(err)
}
apiLists := make([]*apimodel.List, 0, len(lists))
for _, list := range lists {
apiList, errWithCode := p.apiList(ctx, list)
if errWithCode != nil {
return nil, errWithCode
}
apiLists = append(apiLists, apiList)
}
return apiLists, nil
}
// GetListAccounts returns accounts that are in the given list, owned by the given account.
// The additional parameters can be used for paging.
func (p *Processor) GetListAccounts(
ctx context.Context,
account *gtsmodel.Account,
listID string,
maxID string,
sinceID string,
minID string,
limit int,
) (*apimodel.PageableResponse, gtserror.WithCode) {
// Ensure list exists + is owned by requesting account.
if _, errWithCode := p.getList(ctx, account.ID, listID); errWithCode != nil {
return nil, errWithCode
}
// To know which accounts are in the list,
// we need to first get requested list entries.
listEntries, err := p.state.DB.GetListEntries(ctx, listID, maxID, sinceID, minID, limit)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
err = fmt.Errorf("GetListAccounts: error getting list entries: %w", err)
return nil, gtserror.NewErrorInternalError(err)
}
count := len(listEntries)
if count == 0 {
// No list entries means no accounts.
return util.EmptyPageableResponse(), nil
}
var (
items = make([]interface{}, count)
nextMaxIDValue string
prevMinIDValue string
)
// For each list entry, we want the account it points to.
// To get this, we need to first get the follow that the
// list entry pertains to, then extract the target account
// from that follow.
//
// We do paging not by account ID, but by list entry ID.
for i, listEntry := range listEntries {
if i == count-1 {
nextMaxIDValue = listEntry.ID
}
if i == 0 {
prevMinIDValue = listEntry.ID
}
if err := p.state.DB.PopulateListEntry(ctx, listEntry); err != nil {
log.Debugf(ctx, "skipping list entry because of error populating it: %q", err)
continue
}
if err := p.state.DB.PopulateFollow(ctx, listEntry.Follow); err != nil {
log.Debugf(ctx, "skipping list entry because of error populating follow: %q", err)
continue
}
apiAccount, err := p.tc.AccountToAPIAccountPublic(ctx, listEntry.Follow.TargetAccount)
if err != nil {
log.Debugf(ctx, "skipping list entry because of error converting follow target account: %q", err)
continue
}
items[i] = apiAccount
}
return util.PackagePageableResponse(util.PageableResponseParams{
Items: items,
Path: "api/v1/lists/" + listID + "/accounts",
NextMaxIDValue: nextMaxIDValue,
PrevMinIDValue: prevMinIDValue,
Limit: limit,
})
}

View file

@ -0,0 +1,35 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package list
import (
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
)
type Processor struct {
state *state.State
tc typeutils.TypeConverter
}
func New(state *state.State, tc typeutils.TypeConverter) Processor {
return Processor{
state: state,
tc: tc,
}
}

View file

@ -0,0 +1,73 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package list
import (
"context"
"errors"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
// Update updates one list for the given account, using the provided parameters.
// These params should have already been validated by the time they reach this function.
func (p *Processor) Update(
ctx context.Context,
account *gtsmodel.Account,
id string,
title *string,
repliesPolicy *gtsmodel.RepliesPolicy,
) (*apimodel.List, gtserror.WithCode) {
list, errWithCode := p.getList(
// Use barebones ctx; no embedded
// structs necessary for this call.
gtscontext.SetBarebones(ctx),
account.ID,
id,
)
if errWithCode != nil {
return nil, errWithCode
}
// Only update columns we're told to update.
columns := make([]string, 0, 2)
if title != nil {
list.Title = *title
columns = append(columns, "title")
}
if repliesPolicy != nil {
list.RepliesPolicy = *repliesPolicy
columns = append(columns, "replies_policy")
}
if err := p.state.DB.UpdateList(ctx, list, columns...); err != nil {
if errors.Is(err, db.ErrAlreadyExists) {
err = errors.New("you already have a list with this title")
return nil, gtserror.NewErrorConflict(err, err.Error())
}
return nil, gtserror.NewErrorInternalError(err)
}
return p.apiList(ctx, list)
}

View file

@ -0,0 +1,151 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package list
import (
"context"
"errors"
"fmt"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/id"
)
// AddToList adds targetAccountIDs to the given list, if valid.
func (p *Processor) AddToList(ctx context.Context, account *gtsmodel.Account, listID string, targetAccountIDs []string) gtserror.WithCode {
// Ensure this list exists + account owns it.
list, errWithCode := p.getList(ctx, account.ID, listID)
if errWithCode != nil {
return errWithCode
}
// Pre-assemble list of entries to add. We *could* add these
// one by one as we iterate through accountIDs, but according
// to the Mastodon API we should only add them all once we know
// they're all valid, no partial updates.
listEntries := make([]*gtsmodel.ListEntry, 0, len(targetAccountIDs))
// Check each targetAccountID is valid.
// - Follow must exist.
// - Follow must not already be in the given list.
for _, targetAccountID := range targetAccountIDs {
// Ensure follow exists.
follow, err := p.state.DB.GetFollow(ctx, account.ID, targetAccountID)
if err != nil {
if errors.Is(err, db.ErrNoEntries) {
err = fmt.Errorf("you do not follow account %s", targetAccountID)
return gtserror.NewErrorNotFound(err, err.Error())
}
return gtserror.NewErrorInternalError(err)
}
// Ensure followID not already in list.
// This particular call to isInList will
// never error, so just check entryID.
entryID, _ := isInList(
list,
follow.ID,
func(listEntry *gtsmodel.ListEntry) (string, error) {
// Looking for the listEntry follow ID.
return listEntry.FollowID, nil
},
)
// Empty entryID means entry with given
// followID wasn't found in the list.
if entryID != "" {
err = fmt.Errorf("account with id %s is already in list %s with entryID %s", targetAccountID, listID, entryID)
return gtserror.NewErrorUnprocessableEntity(err, err.Error())
}
// Entry wasn't in the list, we can add it.
listEntries = append(listEntries, &gtsmodel.ListEntry{
ID: id.NewULID(),
ListID: listID,
FollowID: follow.ID,
})
}
// If we get to here we can assume all
// entries are valid, so try to add them.
if err := p.state.DB.PutListEntries(ctx, listEntries); err != nil {
if errors.Is(err, db.ErrAlreadyExists) {
err = fmt.Errorf("one or more errors inserting list entries: %w", err)
return gtserror.NewErrorUnprocessableEntity(err, err.Error())
}
return gtserror.NewErrorInternalError(err)
}
return nil
}
// RemoveFromList removes targetAccountIDs from the given list, if valid.
func (p *Processor) RemoveFromList(ctx context.Context, account *gtsmodel.Account, listID string, targetAccountIDs []string) gtserror.WithCode {
// Ensure this list exists + account owns it.
list, errWithCode := p.getList(ctx, account.ID, listID)
if errWithCode != nil {
return errWithCode
}
// For each targetAccountID, we want to check if
// a follow with that targetAccountID is in the
// given list. If it is in there, we want to remove
// it from the list.
for _, targetAccountID := range targetAccountIDs {
// Check if targetAccountID is
// on a follow in the list.
entryID, err := isInList(
list,
targetAccountID,
func(listEntry *gtsmodel.ListEntry) (string, error) {
// We need the follow so populate this
// entry, if it's not already populated.
if err := p.state.DB.PopulateListEntry(ctx, listEntry); err != nil {
return "", err
}
// Looking for the list entry targetAccountID.
return listEntry.Follow.TargetAccountID, nil
},
)
// Error may be returned here if there was an issue
// populating the list entry. We only return on proper
// DB errors, we can just skip no entry errors.
if err != nil && !errors.Is(err, db.ErrNoEntries) {
err = fmt.Errorf("error checking if targetAccountID %s was in list %s: %w", targetAccountID, listID, err)
return gtserror.NewErrorInternalError(err)
}
if entryID == "" {
// There was an errNoEntries or targetAccount
// wasn't in this list anyway, so we can skip it.
continue
}
// TargetAccount was in the list, remove the entry.
if err := p.state.DB.DeleteListEntry(ctx, entryID); err != nil && !errors.Is(err, db.ErrNoEntries) {
err = fmt.Errorf("error removing list entry %s from list %s: %w", entryID, listID, err)
return gtserror.NewErrorInternalError(err)
}
}
return nil
}

View file

@ -0,0 +1,85 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package list
import (
"context"
"errors"
"fmt"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
// getList is a shortcut to get one list from the database and
// check that it's owned by the given accountID. Will return
// appropriate errors so caller doesn't need to bother.
func (p *Processor) getList(ctx context.Context, accountID string, listID string) (*gtsmodel.List, gtserror.WithCode) {
list, err := p.state.DB.GetListByID(ctx, listID)
if err != nil {
if errors.Is(err, db.ErrNoEntries) {
// List doesn't seem to exist.
return nil, gtserror.NewErrorNotFound(err)
}
// Real database error.
return nil, gtserror.NewErrorInternalError(err)
}
if list.AccountID != accountID {
err = fmt.Errorf("list with id %s does not belong to account %s", list.ID, accountID)
return nil, gtserror.NewErrorNotFound(err)
}
return list, nil
}
// apiList is a shortcut to return the API version of the given
// list, or return an appropriate error if conversion fails.
func (p *Processor) apiList(ctx context.Context, list *gtsmodel.List) (*apimodel.List, gtserror.WithCode) {
apiList, err := p.tc.ListToAPIList(ctx, list)
if err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("error converting list to api: %w", err))
}
return apiList, nil
}
// isInList check if thisID is equal to the result of thatID
// for any entry in the given list.
//
// Will return the id of the listEntry if true, empty if false,
// or an error if the result of thatID returns an error.
func isInList(
list *gtsmodel.List,
thisID string,
getThatID func(listEntry *gtsmodel.ListEntry) (string, error),
) (string, error) {
for _, listEntry := range list.ListEntries {
thatID, err := getThatID(listEntry)
if err != nil {
return "", err
}
if thisID == thatID {
return listEntry.ID, nil
}
}
return "", nil
}

View file

@ -1,52 +0,0 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package processing_test
import (
"context"
"testing"
"github.com/stretchr/testify/suite"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
)
type NotificationTestSuite struct {
ProcessingStandardTestSuite
}
// get a notification where someone has liked our status
func (suite *NotificationTestSuite) TestGetNotifications() {
receivingAccount := suite.testAccounts["local_account_1"]
notifsResponse, err := suite.processor.NotificationsGet(context.Background(), suite.testAutheds["local_account_1"], "", "", "", 10, nil)
suite.NoError(err)
suite.Len(notifsResponse.Items, 1)
notif, ok := notifsResponse.Items[0].(*apimodel.Notification)
if !ok {
panic("notif in response wasn't *apimodel.Notification")
}
suite.NotNil(notif.Status)
suite.NotNil(notif.Status)
suite.NotNil(notif.Status.Account)
suite.Equal(receivingAccount.ID, notif.Status.Account.ID)
suite.Equal(`<http://localhost:8080/api/v1/notifications?limit=10&max_id=01F8Q0ANPTWW10DAKTX7BRPBJP>; rel="next", <http://localhost:8080/api/v1/notifications?limit=10&min_id=01F8Q0ANPTWW10DAKTX7BRPBJP>; rel="prev"`, notifsResponse.LinkHeader)
}
func TestNotificationTestSuite(t *testing.T) {
suite.Run(t, &NotificationTestSuite{})
}

View file

@ -29,39 +29,41 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/processing/account"
"github.com/superseriousbusiness/gotosocial/internal/processing/admin"
"github.com/superseriousbusiness/gotosocial/internal/processing/fedi"
"github.com/superseriousbusiness/gotosocial/internal/processing/list"
"github.com/superseriousbusiness/gotosocial/internal/processing/media"
"github.com/superseriousbusiness/gotosocial/internal/processing/report"
"github.com/superseriousbusiness/gotosocial/internal/processing/status"
"github.com/superseriousbusiness/gotosocial/internal/processing/stream"
"github.com/superseriousbusiness/gotosocial/internal/processing/timeline"
"github.com/superseriousbusiness/gotosocial/internal/processing/user"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/timeline"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
)
type Processor struct {
federator federation.Federator
tc typeutils.TypeConverter
oauthServer oauth.Server
mediaManager mm.Manager
statusTimelines timeline.Manager
state *state.State
emailSender email.Sender
filter *visibility.Filter
federator federation.Federator
tc typeutils.TypeConverter
oauthServer oauth.Server
mediaManager mm.Manager
state *state.State
emailSender email.Sender
filter *visibility.Filter
/*
SUB-PROCESSORS
*/
account account.Processor
admin admin.Processor
fedi fedi.Processor
media media.Processor
report report.Processor
status status.Processor
stream stream.Processor
user user.Processor
account account.Processor
admin admin.Processor
fedi fedi.Processor
list list.Processor
media media.Processor
report report.Processor
status status.Processor
stream stream.Processor
timeline timeline.Processor
user user.Processor
}
func (p *Processor) Account() *account.Processor {
@ -76,6 +78,10 @@ func (p *Processor) Fedi() *fedi.Processor {
return &p.fedi
}
func (p *Processor) List() *list.Processor {
return &p.list
}
func (p *Processor) Media() *media.Processor {
return &p.media
}
@ -92,6 +98,10 @@ func (p *Processor) Stream() *stream.Processor {
return &p.stream
}
func (p *Processor) Timeline() *timeline.Processor {
return &p.timeline
}
func (p *Processor) User() *user.Processor {
return &p.user
}
@ -114,23 +124,19 @@ func NewProcessor(
tc: tc,
oauthServer: oauthServer,
mediaManager: mediaManager,
statusTimelines: timeline.NewManager(
StatusGrabFunction(state.DB),
StatusFilterFunction(state.DB, filter),
StatusPrepareFunction(state.DB, tc),
StatusSkipInsertFunction(),
),
state: state,
filter: filter,
emailSender: emailSender,
state: state,
filter: filter,
emailSender: emailSender,
}
// sub processors
// Instantiate sub processors.
processor.account = account.New(state, tc, mediaManager, oauthServer, federator, filter, parseMentionFunc)
processor.admin = admin.New(state, tc, mediaManager, federator.TransportController(), emailSender)
processor.fedi = fedi.New(state, tc, federator, filter)
processor.list = list.New(state, tc)
processor.media = media.New(state, tc, mediaManager, federator.TransportController())
processor.report = report.New(state, tc)
processor.timeline = timeline.New(state, tc, filter)
processor.status = status.New(state, federator, tc, filter, parseMentionFunc)
processor.stream = stream.New(state, oauthServer)
processor.user = user.New(state, emailSender)
@ -161,13 +167,3 @@ func (p *Processor) EnqueueFederator(ctx context.Context, msgs ...messages.FromF
}
})
}
// Start starts the Processor.
func (p *Processor) Start() error {
return p.statusTimelines.Start()
}
// Stop stops the processor cleanly.
func (p *Processor) Stop() error {
return p.statusTimelines.Stop()
}

View file

@ -18,6 +18,8 @@
package processing_test
import (
"context"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/email"
@ -28,8 +30,10 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/stream"
"github.com/superseriousbusiness/gotosocial/internal/transport"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@ -61,6 +65,7 @@ type ProcessingStandardTestSuite struct {
testAutheds map[string]*oauth.Auth
testBlocks map[string]*gtsmodel.Block
testActivities map[string]testrig.ActivityWithSignature
testLists map[string]*gtsmodel.List
processor *processing.Processor
}
@ -84,6 +89,7 @@ func (suite *ProcessingStandardTestSuite) SetupSuite() {
},
}
suite.testBlocks = testrig.NewTestBlocks()
suite.testLists = testrig.NewTestLists()
}
func (suite *ProcessingStandardTestSuite) SetupTest() {
@ -99,6 +105,13 @@ func (suite *ProcessingStandardTestSuite) SetupTest() {
suite.storage = testrig.NewInMemoryStorage()
suite.state.Storage = suite.storage
suite.typeconverter = testrig.NewTestTypeConverter(suite.db)
testrig.StartTimelines(
&suite.state,
visibility.NewFilter(&suite.state),
suite.typeconverter,
)
suite.httpClient = testrig.NewMockHTTPClient(nil, "../../testrig/media")
suite.httpClient.TestRemotePeople = testrig.NewTestFediPeople()
suite.httpClient.TestRemoteStatuses = testrig.NewTestFediStatuses()
@ -115,16 +128,40 @@ func (suite *ProcessingStandardTestSuite) SetupTest() {
testrig.StandardDBSetup(suite.db, suite.testAccounts)
testrig.StandardStorageSetup(suite.storage, "../../testrig/media")
if err := suite.processor.Start(); err != nil {
panic(err)
}
}
func (suite *ProcessingStandardTestSuite) TearDownTest() {
testrig.StandardDBTeardown(suite.db)
testrig.StandardStorageTeardown(suite.storage)
if err := suite.processor.Stop(); err != nil {
panic(err)
}
testrig.StopWorkers(&suite.state)
}
func (suite *ProcessingStandardTestSuite) openStreams(ctx context.Context, account *gtsmodel.Account, listIDs []string) map[string]*stream.Stream {
streams := make(map[string]*stream.Stream)
for _, streamType := range []string{
stream.TimelineHome,
stream.TimelinePublic,
stream.TimelineNotifications,
} {
stream, err := suite.processor.Stream().Open(ctx, account, streamType)
if err != nil {
suite.FailNow(err.Error())
}
streams[streamType] = stream
}
for _, listID := range listIDs {
streamType := stream.TimelineList + ":" + listID
stream, err := suite.processor.Stream().Open(ctx, account, streamType)
if err != nil {
suite.FailNow(err.Error())
}
streams[streamType] = stream
}
return streams
}

View file

@ -88,6 +88,12 @@ func (suite *StatusStandardTestSuite) SetupTest() {
suite.federator = testrig.NewTestFederator(&suite.state, suite.tc, suite.mediaManager)
filter := visibility.NewFilter(&suite.state)
testrig.StartTimelines(
&suite.state,
filter,
testrig.NewTestTypeConverter(suite.db),
)
suite.status = status.New(&suite.state, suite.federator, suite.typeConverter, filter, processing.GetParseMentionFunc(suite.db, suite.federator))
testrig.StandardDBSetup(suite.db, suite.testAccounts)

View file

@ -1,309 +0,0 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package processing
import (
"context"
"errors"
"fmt"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/timeline"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
)
const boostReinsertionDepth = 50
// StatusGrabFunction returns a function that satisfies the GrabFunction interface in internal/timeline.
func StatusGrabFunction(database db.DB) timeline.GrabFunction {
return func(ctx context.Context, timelineAccountID string, maxID string, sinceID string, minID string, limit int) ([]timeline.Timelineable, bool, error) {
statuses, err := database.GetHomeTimeline(ctx, timelineAccountID, maxID, sinceID, minID, limit, false)
if err != nil {
if errors.Is(err, db.ErrNoEntries) {
return nil, true, nil // we just don't have enough statuses left in the db so return stop = true
}
return nil, false, fmt.Errorf("statusGrabFunction: error getting statuses from db: %w", err)
}
items := make([]timeline.Timelineable, len(statuses))
for i, s := range statuses {
items[i] = s
}
return items, false, nil
}
}
// StatusFilterFunction returns a function that satisfies the FilterFunction interface in internal/timeline.
func StatusFilterFunction(database db.DB, filter *visibility.Filter) timeline.FilterFunction {
return func(ctx context.Context, timelineAccountID string, item timeline.Timelineable) (shouldIndex bool, err error) {
status, ok := item.(*gtsmodel.Status)
if !ok {
return false, errors.New("StatusFilterFunction: could not convert item to *gtsmodel.Status")
}
requestingAccount, err := database.GetAccountByID(ctx, timelineAccountID)
if err != nil {
return false, fmt.Errorf("StatusFilterFunction: error getting account with id %s: %w", timelineAccountID, err)
}
timelineable, err := filter.StatusHomeTimelineable(ctx, requestingAccount, status)
if err != nil {
return false, fmt.Errorf("StatusFilterFunction: error checking hometimelineability of status %s for account %s: %w", status.ID, timelineAccountID, err)
}
return timelineable, nil
}
}
// StatusPrepareFunction returns a function that satisfies the PrepareFunction interface in internal/timeline.
func StatusPrepareFunction(database db.DB, tc typeutils.TypeConverter) timeline.PrepareFunction {
return func(ctx context.Context, timelineAccountID string, itemID string) (timeline.Preparable, error) {
status, err := database.GetStatusByID(ctx, itemID)
if err != nil {
return nil, fmt.Errorf("StatusPrepareFunction: error getting status with id %s: %w", itemID, err)
}
requestingAccount, err := database.GetAccountByID(ctx, timelineAccountID)
if err != nil {
return nil, fmt.Errorf("StatusPrepareFunction: error getting account with id %s: %w", timelineAccountID, err)
}
return tc.StatusToAPIStatus(ctx, status, requestingAccount)
}
}
// StatusSkipInsertFunction returns a function that satisifes the SkipInsertFunction interface in internal/timeline.
func StatusSkipInsertFunction() timeline.SkipInsertFunction {
return func(
ctx context.Context,
newItemID string,
newItemAccountID string,
newItemBoostOfID string,
newItemBoostOfAccountID string,
nextItemID string,
nextItemAccountID string,
nextItemBoostOfID string,
nextItemBoostOfAccountID string,
depth int,
) (bool, error) {
// make sure we don't insert a duplicate
if newItemID == nextItemID {
return true, nil
}
// check if it's a boost
if newItemBoostOfID != "" {
// skip if we've recently put another boost of this status in the timeline
if newItemBoostOfID == nextItemBoostOfID {
if depth < boostReinsertionDepth {
return true, nil
}
}
// skip if we've recently put the original status in the timeline
if newItemBoostOfID == nextItemID {
if depth < boostReinsertionDepth {
return true, nil
}
}
}
// insert the item
return false, nil
}
}
func (p *Processor) HomeTimelineGet(ctx context.Context, authed *oauth.Auth, maxID string, sinceID string, minID string, limit int, local bool) (*apimodel.PageableResponse, gtserror.WithCode) {
statuses, err := p.statusTimelines.GetTimeline(ctx, authed.Account.ID, maxID, sinceID, minID, limit, local)
if err != nil {
err = fmt.Errorf("HomeTimelineGet: error getting statuses: %w", err)
return nil, gtserror.NewErrorInternalError(err)
}
count := len(statuses)
if count == 0 {
return util.EmptyPageableResponse(), nil
}
var (
items = make([]interface{}, count)
nextMaxIDValue string
prevMinIDValue string
)
for i, item := range statuses {
if i == count-1 {
nextMaxIDValue = item.GetID()
}
if i == 0 {
prevMinIDValue = item.GetID()
}
items[i] = item
}
return util.PackagePageableResponse(util.PageableResponseParams{
Items: items,
Path: "api/v1/timelines/home",
NextMaxIDValue: nextMaxIDValue,
PrevMinIDValue: prevMinIDValue,
Limit: limit,
})
}
func (p *Processor) PublicTimelineGet(ctx context.Context, authed *oauth.Auth, maxID string, sinceID string, minID string, limit int, local bool) (*apimodel.PageableResponse, gtserror.WithCode) {
statuses, err := p.state.DB.GetPublicTimeline(ctx, maxID, sinceID, minID, limit, local)
if err != nil {
if errors.Is(err, db.ErrNoEntries) {
// No statuses (left) in public timeline.
return util.EmptyPageableResponse(), nil
}
// An actual error has occurred.
err = fmt.Errorf("PublicTimelineGet: db error getting statuses: %w", err)
return nil, gtserror.NewErrorInternalError(err)
}
count := len(statuses)
if count == 0 {
return util.EmptyPageableResponse(), nil
}
var (
items = make([]interface{}, 0, count)
nextMaxIDValue string
prevMinIDValue string
)
for i, s := range statuses {
// Set next + prev values before filtering and API
// converting, so caller can still page properly.
if i == count-1 {
nextMaxIDValue = s.ID
}
if i == 0 {
prevMinIDValue = s.ID
}
timelineable, err := p.filter.StatusPublicTimelineable(ctx, authed.Account, s)
if err != nil {
log.Debugf(ctx, "skipping status %s because of an error checking StatusPublicTimelineable: %s", s.ID, err)
continue
}
if !timelineable {
continue
}
apiStatus, err := p.tc.StatusToAPIStatus(ctx, s, authed.Account)
if err != nil {
log.Debugf(ctx, "skipping status %s because it couldn't be converted to its api representation: %s", s.ID, err)
continue
}
items = append(items, apiStatus)
}
return util.PackagePageableResponse(util.PageableResponseParams{
Items: items,
Path: "api/v1/timelines/public",
NextMaxIDValue: nextMaxIDValue,
PrevMinIDValue: prevMinIDValue,
Limit: limit,
})
}
func (p *Processor) FavedTimelineGet(ctx context.Context, authed *oauth.Auth, maxID string, minID string, limit int) (*apimodel.PageableResponse, gtserror.WithCode) {
statuses, nextMaxID, prevMinID, err := p.state.DB.GetFavedTimeline(ctx, authed.Account.ID, maxID, minID, limit)
if err != nil {
if errors.Is(err, db.ErrNoEntries) {
// There are just no entries (left).
return util.EmptyPageableResponse(), nil
}
// An actual error has occurred.
err = fmt.Errorf("FavedTimelineGet: db error getting statuses: %w", err)
return nil, gtserror.NewErrorInternalError(err)
}
count := len(statuses)
if count == 0 {
return util.EmptyPageableResponse(), nil
}
filtered, err := p.filterFavedStatuses(ctx, authed, statuses)
if err != nil {
err = fmt.Errorf("FavedTimelineGet: error filtering statuses: %w", err)
return nil, gtserror.NewErrorInternalError(err)
}
items := make([]interface{}, len(filtered))
for i, item := range filtered {
items[i] = item
}
return util.PackagePageableResponse(util.PageableResponseParams{
Items: items,
Path: "api/v1/favourites",
NextMaxIDValue: nextMaxID,
PrevMinIDValue: prevMinID,
Limit: limit,
})
}
func (p *Processor) filterFavedStatuses(ctx context.Context, authed *oauth.Auth, statuses []*gtsmodel.Status) ([]*apimodel.Status, error) {
apiStatuses := make([]*apimodel.Status, 0, len(statuses))
for _, s := range statuses {
if _, err := p.state.DB.GetAccountByID(ctx, s.AccountID); err != nil {
if errors.Is(err, db.ErrNoEntries) {
log.Debugf(ctx, "skipping status %s because account %s can't be found in the db", s.ID, s.AccountID)
continue
}
err = fmt.Errorf("filterFavedStatuses: db error getting status author: %w", err)
return nil, gtserror.NewErrorInternalError(err)
}
timelineable, err := p.filter.StatusVisible(ctx, authed.Account, s)
if err != nil {
log.Debugf(ctx, "skipping status %s because of an error checking status visibility: %s", s.ID, err)
continue
}
if !timelineable {
continue
}
apiStatus, err := p.tc.StatusToAPIStatus(ctx, s, authed.Account)
if err != nil {
log.Debugf(ctx, "skipping status %s because it couldn't be converted to its api representation: %s", s.ID, err)
continue
}
apiStatuses = append(apiStatuses, apiStatus)
}
return apiStatuses, nil
}

View file

@ -31,60 +31,65 @@ import (
)
// Open returns a new Stream for the given account, which will contain a channel for passing messages back to the caller.
func (p *Processor) Open(ctx context.Context, account *gtsmodel.Account, streamTimeline string) (*stream.Stream, gtserror.WithCode) {
func (p *Processor) Open(ctx context.Context, account *gtsmodel.Account, streamType string) (*stream.Stream, gtserror.WithCode) {
l := log.WithContext(ctx).WithFields(kv.Fields{
{"account", account.ID},
{"streamType", streamTimeline},
{"streamType", streamType},
}...)
l.Debug("received open stream request")
// each stream needs a unique ID so we know to close it
streamID, err := id.NewRandomULID()
var (
streamID string
err error
)
// Each stream needs a unique ID so we know to close it.
streamID, err = id.NewRandomULID()
if err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("error generating stream id: %s", err))
return nil, gtserror.NewErrorInternalError(fmt.Errorf("error generating stream id: %w", err))
}
// Each stream can be subscibed to multiple timelines.
// Each stream can be subscibed to multiple types.
// Record them in a set, and include the initial one
// if it was given to us
timelines := map[string]bool{}
if streamTimeline != "" {
timelines[streamTimeline] = true
// if it was given to us.
streamTypes := map[string]any{}
if streamType != "" {
streamTypes[streamType] = true
}
thisStream := &stream.Stream{
ID: streamID,
Timelines: timelines,
Messages: make(chan *stream.Message, 100),
Hangup: make(chan interface{}, 1),
Connected: true,
newStream := &stream.Stream{
ID: streamID,
StreamTypes: streamTypes,
Messages: make(chan *stream.Message, 100),
Hangup: make(chan interface{}, 1),
Connected: true,
}
go p.waitToCloseStream(account, thisStream)
go p.waitToCloseStream(account, newStream)
v, ok := p.streamMap.Load(account.ID)
if !ok || v == nil {
// there is no entry in the streamMap for this account yet, so make one and store it
streamsForAccount := &stream.StreamsForAccount{
Streams: []*stream.Stream{
thisStream,
},
}
p.streamMap.Store(account.ID, streamsForAccount)
} else {
// there is an entry in the streamMap for this account
// parse the interface as a streamsForAccount
if ok {
// There is an entry in the streamMap
// for this account. Parse it out.
streamsForAccount, ok := v.(*stream.StreamsForAccount)
if !ok {
return nil, gtserror.NewErrorInternalError(errors.New("stream map error"))
}
// append this stream to it
// Append new stream to existing entry.
streamsForAccount.Lock()
streamsForAccount.Streams = append(streamsForAccount.Streams, thisStream)
streamsForAccount.Streams = append(streamsForAccount.Streams, newStream)
streamsForAccount.Unlock()
} else {
// There is no entry in the streamMap for
// this account yet. Create one and store it.
p.streamMap.Store(account.ID, &stream.StreamsForAccount{
Streams: []*stream.Stream{
newStream,
},
})
}
return thisStream, nil
return newStream, nil
}
// waitToCloseStream waits until the hangup channel is closed for the given stream.

View file

@ -18,7 +18,6 @@
package stream
import (
"errors"
"sync"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
@ -40,37 +39,38 @@ func New(state *state.State, oauthServer oauth.Server) Processor {
}
// toAccount streams the given payload with the given event type to any streams currently open for the given account ID.
func (p *Processor) toAccount(payload string, event string, timelines []string, accountID string) error {
func (p *Processor) toAccount(payload string, event string, streamTypes []string, accountID string) error {
// Load all streams open for this account.
v, ok := p.streamMap.Load(accountID)
if !ok {
// no open connections so nothing to stream
return nil
}
streamsForAccount, ok := v.(*stream.StreamsForAccount)
if !ok {
return errors.New("stream map error")
return nil // No entry = nothing to stream.
}
streamsForAccount := v.(*stream.StreamsForAccount) //nolint:forcetypeassert
streamsForAccount.Lock()
defer streamsForAccount.Unlock()
for _, s := range streamsForAccount.Streams {
s.Lock()
defer s.Unlock()
if !s.Connected {
continue
}
for _, t := range timelines {
if _, found := s.Timelines[t]; found {
typeLoop:
for _, streamType := range streamTypes {
if _, found := s.StreamTypes[streamType]; found {
s.Messages <- &stream.Message{
Stream: []string{string(t)},
Stream: []string{streamType},
Event: string(event),
Payload: payload,
}
// break out to the outer loop, to avoid sending duplicates
// of the same event to the same stream
break
// Break out to the outer loop,
// to avoid sending duplicates of
// the same event to the same stream.
break typeLoop
}
}
}

View file

@ -27,11 +27,11 @@ import (
)
// Update streams the given update to any open, appropriate streams belonging to the given account.
func (p *Processor) Update(s *apimodel.Status, account *gtsmodel.Account, timeline string) error {
func (p *Processor) Update(s *apimodel.Status, account *gtsmodel.Account, streamTypes []string) error {
bytes, err := json.Marshal(s)
if err != nil {
return fmt.Errorf("error marshalling status to json: %s", err)
}
return p.toAccount(string(bytes), stream.EventTypeUpdate, []string{timeline}, account.ID)
return p.toAccount(string(bytes), stream.EventTypeUpdate, streamTypes, account.ID)
}

View file

@ -0,0 +1,71 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package timeline
import (
"context"
"github.com/superseriousbusiness/gotosocial/internal/timeline"
)
// SkipInsert returns a function that satisifes SkipInsertFunction.
func SkipInsert() timeline.SkipInsertFunction {
// Gap to allow between a status or boost of status,
// and reinsertion of a new boost of that status.
// This is useful to avoid a heavily boosted status
// showing up way too often in a user's timeline.
const boostReinsertionDepth = 50
return func(
ctx context.Context,
newItemID string,
newItemAccountID string,
newItemBoostOfID string,
newItemBoostOfAccountID string,
nextItemID string,
nextItemAccountID string,
nextItemBoostOfID string,
nextItemBoostOfAccountID string,
depth int,
) (bool, error) {
if newItemID == nextItemID {
// Don't insert duplicates.
return true, nil
}
if newItemBoostOfID != "" {
if newItemBoostOfID == nextItemBoostOfID &&
depth < boostReinsertionDepth {
// Don't insert boosts of items
// we've seen boosted recently.
return true, nil
}
if newItemBoostOfID == nextItemID &&
depth < boostReinsertionDepth {
// Don't insert boosts of items when
// we've seen the original recently.
return true, nil
}
}
// Proceed with insertion
// (that's what she said!).
return false, nil
}
}

View file

@ -0,0 +1,73 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package timeline
import (
"context"
"errors"
"fmt"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/util"
)
func (p *Processor) FavedTimelineGet(ctx context.Context, authed *oauth.Auth, maxID string, minID string, limit int) (*apimodel.PageableResponse, gtserror.WithCode) {
statuses, nextMaxID, prevMinID, err := p.state.DB.GetFavedTimeline(ctx, authed.Account.ID, maxID, minID, limit)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
err = fmt.Errorf("FavedTimelineGet: db error getting statuses: %w", err)
return nil, gtserror.NewErrorInternalError(err)
}
count := len(statuses)
if count == 0 {
return util.EmptyPageableResponse(), nil
}
items := make([]interface{}, 0, count)
for _, s := range statuses {
visible, err := p.filter.StatusVisible(ctx, authed.Account, s)
if err != nil {
log.Debugf(ctx, "skipping status %s because of an error checking status visibility: %s", s.ID, err)
continue
}
if !visible {
continue
}
apiStatus, err := p.tc.StatusToAPIStatus(ctx, s, authed.Account)
if err != nil {
log.Debugf(ctx, "skipping status %s because it couldn't be converted to its api representation: %s", s.ID, err)
continue
}
items = append(items, apiStatus)
}
return util.PackagePageableResponse(util.PageableResponseParams{
Items: items,
Path: "api/v1/favourites",
NextMaxIDValue: nextMaxID,
PrevMinIDValue: prevMinID,
Limit: limit,
})
}

View file

@ -0,0 +1,133 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package timeline
import (
"context"
"errors"
"fmt"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/timeline"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
)
// HomeTimelineGrab returns a function that satisfies GrabFunction for home timelines.
func HomeTimelineGrab(state *state.State) timeline.GrabFunction {
return func(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int) ([]timeline.Timelineable, bool, error) {
statuses, err := state.DB.GetHomeTimeline(ctx, accountID, maxID, sinceID, minID, limit, false)
if err != nil {
if errors.Is(err, db.ErrNoEntries) {
return nil, true, nil // we just don't have enough statuses left in the db so return stop = true
}
return nil, false, fmt.Errorf("HomeTimelineGrab: error getting statuses from db: %w", err)
}
items := make([]timeline.Timelineable, len(statuses))
for i, s := range statuses {
items[i] = s
}
return items, false, nil
}
}
// HomeTimelineFilter returns a function that satisfies FilterFunction for home timelines.
func HomeTimelineFilter(state *state.State, filter *visibility.Filter) timeline.FilterFunction {
return func(ctx context.Context, accountID string, item timeline.Timelineable) (shouldIndex bool, err error) {
status, ok := item.(*gtsmodel.Status)
if !ok {
return false, errors.New("HomeTimelineFilter: could not convert item to *gtsmodel.Status")
}
requestingAccount, err := state.DB.GetAccountByID(ctx, accountID)
if err != nil {
return false, fmt.Errorf("HomeTimelineFilter: error getting account with id %s: %w", accountID, err)
}
timelineable, err := filter.StatusHomeTimelineable(ctx, requestingAccount, status)
if err != nil {
return false, fmt.Errorf("HomeTimelineFilter: error checking hometimelineability of status %s for account %s: %w", status.ID, accountID, err)
}
return timelineable, nil
}
}
// HomeTimelineStatusPrepare returns a function that satisfies PrepareFunction for home timelines.
func HomeTimelineStatusPrepare(state *state.State, tc typeutils.TypeConverter) timeline.PrepareFunction {
return func(ctx context.Context, accountID string, itemID string) (timeline.Preparable, error) {
status, err := state.DB.GetStatusByID(ctx, itemID)
if err != nil {
return nil, fmt.Errorf("StatusPrepare: error getting status with id %s: %w", itemID, err)
}
requestingAccount, err := state.DB.GetAccountByID(ctx, accountID)
if err != nil {
return nil, fmt.Errorf("StatusPrepare: error getting account with id %s: %w", accountID, err)
}
return tc.StatusToAPIStatus(ctx, status, requestingAccount)
}
}
func (p *Processor) HomeTimelineGet(ctx context.Context, authed *oauth.Auth, maxID string, sinceID string, minID string, limit int, local bool) (*apimodel.PageableResponse, gtserror.WithCode) {
statuses, err := p.state.Timelines.Home.GetTimeline(ctx, authed.Account.ID, maxID, sinceID, minID, limit, local)
if err != nil {
err = fmt.Errorf("HomeTimelineGet: error getting statuses: %w", err)
return nil, gtserror.NewErrorInternalError(err)
}
count := len(statuses)
if count == 0 {
return util.EmptyPageableResponse(), nil
}
var (
items = make([]interface{}, count)
nextMaxIDValue string
prevMinIDValue string
)
for i, item := range statuses {
if i == count-1 {
nextMaxIDValue = item.GetID()
}
if i == 0 {
prevMinIDValue = item.GetID()
}
items[i] = item
}
return util.PackagePageableResponse(util.PageableResponseParams{
Items: items,
Path: "api/v1/timelines/home",
NextMaxIDValue: nextMaxIDValue,
PrevMinIDValue: prevMinIDValue,
Limit: limit,
})
}

View file

@ -0,0 +1,157 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package timeline
import (
"context"
"errors"
"fmt"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/timeline"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
)
// ListTimelineGrab returns a function that satisfies GrabFunction for list timelines.
func ListTimelineGrab(state *state.State) timeline.GrabFunction {
return func(ctx context.Context, listID string, maxID string, sinceID string, minID string, limit int) ([]timeline.Timelineable, bool, error) {
statuses, err := state.DB.GetListTimeline(ctx, listID, maxID, sinceID, minID, limit)
if err != nil {
if errors.Is(err, db.ErrNoEntries) {
return nil, true, nil // we just don't have enough statuses left in the db so return stop = true
}
return nil, false, fmt.Errorf("ListTimelineGrab: error getting statuses from db: %w", err)
}
items := make([]timeline.Timelineable, len(statuses))
for i, s := range statuses {
items[i] = s
}
return items, false, nil
}
}
// HomeTimelineFilter returns a function that satisfies FilterFunction for list timelines.
func ListTimelineFilter(state *state.State, filter *visibility.Filter) timeline.FilterFunction {
return func(ctx context.Context, listID string, item timeline.Timelineable) (shouldIndex bool, err error) {
status, ok := item.(*gtsmodel.Status)
if !ok {
return false, errors.New("ListTimelineFilter: could not convert item to *gtsmodel.Status")
}
list, err := state.DB.GetListByID(ctx, listID)
if err != nil {
return false, fmt.Errorf("ListTimelineFilter: error getting list with id %s: %w", listID, err)
}
requestingAccount, err := state.DB.GetAccountByID(ctx, list.AccountID)
if err != nil {
return false, fmt.Errorf("ListTimelineFilter: error getting account with id %s: %w", list.AccountID, err)
}
timelineable, err := filter.StatusHomeTimelineable(ctx, requestingAccount, status)
if err != nil {
return false, fmt.Errorf("ListTimelineFilter: error checking hometimelineability of status %s for account %s: %w", status.ID, list.AccountID, err)
}
return timelineable, nil
}
}
// ListTimelineStatusPrepare returns a function that satisfies PrepareFunction for list timelines.
func ListTimelineStatusPrepare(state *state.State, tc typeutils.TypeConverter) timeline.PrepareFunction {
return func(ctx context.Context, listID string, itemID string) (timeline.Preparable, error) {
status, err := state.DB.GetStatusByID(ctx, itemID)
if err != nil {
return nil, fmt.Errorf("ListTimelineStatusPrepare: error getting status with id %s: %w", itemID, err)
}
list, err := state.DB.GetListByID(ctx, listID)
if err != nil {
return nil, fmt.Errorf("ListTimelineStatusPrepare: error getting list with id %s: %w", listID, err)
}
requestingAccount, err := state.DB.GetAccountByID(ctx, list.AccountID)
if err != nil {
return nil, fmt.Errorf("ListTimelineStatusPrepare: error getting account with id %s: %w", list.AccountID, err)
}
return tc.StatusToAPIStatus(ctx, status, requestingAccount)
}
}
func (p *Processor) ListTimelineGet(ctx context.Context, authed *oauth.Auth, listID string, maxID string, sinceID string, minID string, limit int) (*apimodel.PageableResponse, gtserror.WithCode) {
// Ensure list exists + is owned by this account.
list, err := p.state.DB.GetListByID(ctx, listID)
if err != nil {
if errors.Is(err, db.ErrNoEntries) {
return nil, gtserror.NewErrorNotFound(err)
}
return nil, gtserror.NewErrorInternalError(err)
}
if list.AccountID != authed.Account.ID {
err = fmt.Errorf("list with id %s does not belong to account %s", list.ID, authed.Account.ID)
return nil, gtserror.NewErrorNotFound(err)
}
statuses, err := p.state.Timelines.List.GetTimeline(ctx, listID, maxID, sinceID, minID, limit, false)
if err != nil {
err = fmt.Errorf("ListTimelineGet: error getting statuses: %w", err)
return nil, gtserror.NewErrorInternalError(err)
}
count := len(statuses)
if count == 0 {
return util.EmptyPageableResponse(), nil
}
var (
items = make([]interface{}, count)
nextMaxIDValue string
prevMinIDValue string
)
for i, item := range statuses {
if i == count-1 {
nextMaxIDValue = item.GetID()
}
if i == 0 {
prevMinIDValue = item.GetID()
}
items[i] = item
}
return util.PackagePageableResponse(util.PageableResponseParams{
Items: items,
Path: "api/v1/timelines/list/" + listID,
NextMaxIDValue: nextMaxIDValue,
PrevMinIDValue: prevMinIDValue,
Limit: limit,
})
}

View file

@ -15,7 +15,7 @@
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package processing
package timeline
import (
"context"
@ -33,12 +33,7 @@ import (
func (p *Processor) NotificationsGet(ctx context.Context, authed *oauth.Auth, maxID string, sinceID string, minID string, limit int, excludeTypes []string) (*apimodel.PageableResponse, gtserror.WithCode) {
notifs, err := p.state.DB.GetAccountNotifications(ctx, authed.Account.ID, maxID, sinceID, minID, limit, excludeTypes)
if err != nil {
if errors.Is(err, db.ErrNoEntries) {
// No notifs (left).
return util.EmptyPageableResponse(), nil
}
// An actual error has occurred.
if err != nil && !errors.Is(err, db.ErrNoEntries) {
err = fmt.Errorf("NotificationsGet: db error getting notifications: %w", err)
return nil, gtserror.NewErrorInternalError(err)
}
@ -73,6 +68,7 @@ func (p *Processor) NotificationsGet(ctx context.Context, authed *oauth.Auth, ma
log.Debugf(ctx, "skipping notification %s because of an error checking notification visibility: %s", n.ID, err)
continue
}
if !visible {
continue
}
@ -85,6 +81,7 @@ func (p *Processor) NotificationsGet(ctx context.Context, authed *oauth.Auth, ma
log.Debugf(ctx, "skipping notification %s because of an error checking notification visibility: %s", n.ID, err)
continue
}
if !visible {
continue
}

View file

@ -0,0 +1,88 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package timeline
import (
"context"
"errors"
"fmt"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/util"
)
func (p *Processor) PublicTimelineGet(ctx context.Context, authed *oauth.Auth, maxID string, sinceID string, minID string, limit int, local bool) (*apimodel.PageableResponse, gtserror.WithCode) {
statuses, err := p.state.DB.GetPublicTimeline(ctx, maxID, sinceID, minID, limit, local)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
err = fmt.Errorf("PublicTimelineGet: db error getting statuses: %w", err)
return nil, gtserror.NewErrorInternalError(err)
}
count := len(statuses)
if count == 0 {
return util.EmptyPageableResponse(), nil
}
var (
items = make([]interface{}, 0, count)
nextMaxIDValue string
prevMinIDValue string
)
for i, s := range statuses {
// Set next + prev values before filtering and API
// converting, so caller can still page properly.
if i == count-1 {
nextMaxIDValue = s.ID
}
if i == 0 {
prevMinIDValue = s.ID
}
timelineable, err := p.filter.StatusPublicTimelineable(ctx, authed.Account, s)
if err != nil {
log.Debugf(ctx, "skipping status %s because of an error checking StatusPublicTimelineable: %s", s.ID, err)
continue
}
if !timelineable {
continue
}
apiStatus, err := p.tc.StatusToAPIStatus(ctx, s, authed.Account)
if err != nil {
log.Debugf(ctx, "skipping status %s because it couldn't be converted to its api representation: %s", s.ID, err)
continue
}
items = append(items, apiStatus)
}
return util.PackagePageableResponse(util.PageableResponseParams{
Items: items,
Path: "api/v1/timelines/public",
NextMaxIDValue: nextMaxIDValue,
PrevMinIDValue: prevMinIDValue,
Limit: limit,
})
}

View file

@ -0,0 +1,38 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package timeline
import (
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
)
type Processor struct {
state *state.State
tc typeutils.TypeConverter
filter *visibility.Filter
}
func New(state *state.State, tc typeutils.TypeConverter, filter *visibility.Filter) Processor {
return Processor{
state: state,
tc: tc,
filter: filter,
}
}

Some files were not shown because too many files have changed in this diff Show more