mirror of
https://github.com/superseriousbusiness/gotosocial.git
synced 2024-11-22 01:05:53 +03:00
[bugfix] fix higher-level explicit domain rules causing issues with lower-level domain blocking (#2513)
* fix the sort direction of domain cache child nodes ...
* add more domain cache test cases
* add specific test for this bug to database domain test suite (thanks for writing this @tsmethurst!)
* remove unused field (this was a previous attempt at a fix)
* remove debugging println statements 😇
This commit is contained in:
parent
d5c305dc6e
commit
ccecf5a7e4
3 changed files with 114 additions and 19 deletions
47
internal/cache/domain/domain.go
vendored
47
internal/cache/domain/domain.go
vendored
|
@ -19,11 +19,10 @@ package domain
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
"golang.org/x/exp/slices"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Cache provides a means of caching domains in memory to reduce
|
// Cache provides a means of caching domains in memory to reduce
|
||||||
|
@ -58,6 +57,24 @@ func (c *Cache) Matches(domain string, load func() ([]string, error)) (bool, err
|
||||||
return false, fmt.Errorf("error reloading cache: %w", err)
|
return false, fmt.Errorf("error reloading cache: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Ensure the domains being inserted into the cache
|
||||||
|
// are sorted by number of domain parts. i.e. those
|
||||||
|
// with less parts are inserted last, else this can
|
||||||
|
// allow domains to fall through the matching code!
|
||||||
|
slices.SortFunc(domains, func(a, b string) int {
|
||||||
|
const k = +1
|
||||||
|
an := strings.Count(a, ".")
|
||||||
|
bn := strings.Count(b, ".")
|
||||||
|
switch {
|
||||||
|
case an < bn:
|
||||||
|
return +k
|
||||||
|
case an > bn:
|
||||||
|
return -k
|
||||||
|
default:
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
// Allocate new radix trie
|
// Allocate new radix trie
|
||||||
// node to store matches.
|
// node to store matches.
|
||||||
root := new(root)
|
root := new(root)
|
||||||
|
@ -98,13 +115,13 @@ type root struct{ root node }
|
||||||
|
|
||||||
// Add will add the given domain to the radix trie.
|
// Add will add the given domain to the radix trie.
|
||||||
func (r *root) Add(domain string) {
|
func (r *root) Add(domain string) {
|
||||||
r.root.add(strings.Split(domain, "."))
|
r.root.Add(strings.Split(domain, "."))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Match will return whether the given domain matches
|
// Match will return whether the given domain matches
|
||||||
// an existing stored domain in this radix trie.
|
// an existing stored domain in this radix trie.
|
||||||
func (r *root) Match(domain string) bool {
|
func (r *root) Match(domain string) bool {
|
||||||
return r.root.match(strings.Split(domain, "."))
|
return r.root.Match(strings.Split(domain, "."))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Sort will sort the entire radix trie ensuring that
|
// Sort will sort the entire radix trie ensuring that
|
||||||
|
@ -118,7 +135,7 @@ func (r *root) Sort() {
|
||||||
// String returns a string representation of node (and its descendants).
|
// String returns a string representation of node (and its descendants).
|
||||||
func (r *root) String() string {
|
func (r *root) String() string {
|
||||||
buf := new(strings.Builder)
|
buf := new(strings.Builder)
|
||||||
r.root.writestr(buf, "")
|
r.root.WriteStr(buf, "")
|
||||||
return buf.String()
|
return buf.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -127,7 +144,7 @@ type node struct {
|
||||||
child []*node
|
child []*node
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *node) add(parts []string) {
|
func (n *node) Add(parts []string) {
|
||||||
if len(parts) == 0 {
|
if len(parts) == 0 {
|
||||||
panic("invalid domain")
|
panic("invalid domain")
|
||||||
}
|
}
|
||||||
|
@ -169,7 +186,7 @@ func (n *node) add(parts []string) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *node) match(parts []string) bool {
|
func (n *node) Match(parts []string) bool {
|
||||||
for len(parts) > 0 {
|
for len(parts) > 0 {
|
||||||
// Pop next domain part.
|
// Pop next domain part.
|
||||||
i := len(parts) - 1
|
i := len(parts) - 1
|
||||||
|
@ -230,8 +247,16 @@ func (n *node) getChild(part string) *node {
|
||||||
|
|
||||||
func (n *node) sort() {
|
func (n *node) sort() {
|
||||||
// Sort this node's slice of child nodes.
|
// Sort this node's slice of child nodes.
|
||||||
slices.SortFunc(n.child, func(i, j *node) bool {
|
slices.SortFunc(n.child, func(i, j *node) int {
|
||||||
return i.part < j.part
|
const k = -1
|
||||||
|
switch {
|
||||||
|
case i.part < j.part:
|
||||||
|
return +k
|
||||||
|
case i.part > j.part:
|
||||||
|
return -k
|
||||||
|
default:
|
||||||
|
return 0
|
||||||
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
// Sort each child node's children.
|
// Sort each child node's children.
|
||||||
|
@ -240,7 +265,7 @@ func (n *node) sort() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *node) writestr(buf *strings.Builder, prefix string) {
|
func (n *node) WriteStr(buf *strings.Builder, prefix string) {
|
||||||
if prefix != "" {
|
if prefix != "" {
|
||||||
// Suffix joining '.'
|
// Suffix joining '.'
|
||||||
prefix += "."
|
prefix += "."
|
||||||
|
@ -255,6 +280,6 @@ func (n *node) writestr(buf *strings.Builder, prefix string) {
|
||||||
|
|
||||||
// Iterate through node children.
|
// Iterate through node children.
|
||||||
for _, child := range n.child {
|
for _, child := range n.child {
|
||||||
child.writestr(buf, prefix)
|
child.WriteStr(buf, prefix)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
24
internal/cache/domain/domain_test.go
vendored
24
internal/cache/domain/domain_test.go
vendored
|
@ -28,9 +28,13 @@ func TestCache(t *testing.T) {
|
||||||
c := new(domain.Cache)
|
c := new(domain.Cache)
|
||||||
|
|
||||||
cachedDomains := []string{
|
cachedDomains := []string{
|
||||||
"google.com",
|
"google.com", //
|
||||||
"google.co.uk",
|
"mail.google.com", // should be ignored since covered above
|
||||||
"pleroma.bad.host",
|
"dev.mail.google.com", // same again
|
||||||
|
"google.co.uk", //
|
||||||
|
"mail.google.co.uk", //
|
||||||
|
"pleroma.bad.host", //
|
||||||
|
"pleroma.still.a.bad.host", //
|
||||||
}
|
}
|
||||||
|
|
||||||
loader := func() ([]string, error) {
|
loader := func() ([]string, error) {
|
||||||
|
@ -38,22 +42,25 @@ func TestCache(t *testing.T) {
|
||||||
return cachedDomains, nil
|
return cachedDomains, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check a list of known cached domains.
|
// Check a list of known matching domains.
|
||||||
for _, domain := range []string{
|
for _, domain := range []string{
|
||||||
"google.com",
|
"google.com",
|
||||||
"mail.google.com",
|
"mail.google.com",
|
||||||
|
"dev.mail.google.com",
|
||||||
"google.co.uk",
|
"google.co.uk",
|
||||||
"mail.google.co.uk",
|
"mail.google.co.uk",
|
||||||
"pleroma.bad.host",
|
"pleroma.bad.host",
|
||||||
"dev.pleroma.bad.host",
|
"dev.pleroma.bad.host",
|
||||||
|
"pleroma.still.a.bad.host",
|
||||||
|
"dev.pleroma.still.a.bad.host",
|
||||||
} {
|
} {
|
||||||
t.Logf("checking domain matches: %s", domain)
|
t.Logf("checking domain matches: %s", domain)
|
||||||
if b, _ := c.Matches(domain, loader); !b {
|
if b, _ := c.Matches(domain, loader); !b {
|
||||||
t.Errorf("domain should be matched: %s", domain)
|
t.Fatalf("domain should be matched: %s", domain)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check a list of known uncached domains.
|
// Check a list of known unmatched domains.
|
||||||
for _, domain := range []string{
|
for _, domain := range []string{
|
||||||
"askjeeves.com",
|
"askjeeves.com",
|
||||||
"ask-kim.co.uk",
|
"ask-kim.co.uk",
|
||||||
|
@ -61,10 +68,11 @@ func TestCache(t *testing.T) {
|
||||||
"mail.google.ie",
|
"mail.google.ie",
|
||||||
"gts.bad.host",
|
"gts.bad.host",
|
||||||
"mastodon.bad.host",
|
"mastodon.bad.host",
|
||||||
|
"akkoma.still.a.bad.host",
|
||||||
} {
|
} {
|
||||||
t.Logf("checking domain isn't matched: %s", domain)
|
t.Logf("checking domain isn't matched: %s", domain)
|
||||||
if b, _ := c.Matches(domain, loader); b {
|
if b, _ := c.Matches(domain, loader); b {
|
||||||
t.Errorf("domain should not be matched: %s", domain)
|
t.Fatalf("domain should not be matched: %s", domain)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -80,6 +88,6 @@ func TestCache(t *testing.T) {
|
||||||
t.Log("load: returning known error")
|
t.Log("load: returning known error")
|
||||||
return nil, knownErr
|
return nil, knownErr
|
||||||
}); !errors.Is(err, knownErr) {
|
}); !errors.Is(err, knownErr) {
|
||||||
t.Errorf("matches did not return expected error: %v", err)
|
t.Fatalf("matches did not return expected error: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,6 +19,7 @@ package bundb_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"slices"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -212,6 +213,67 @@ func (suite *DomainTestSuite) TestIsDomainBlockedNonASCII2() {
|
||||||
suite.True(blocked)
|
suite.True(blocked)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (suite *DomainTestSuite) TestIsOtherDomainBlockedWildcardAndExplicit() {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
blocks := []*gtsmodel.DomainBlock{
|
||||||
|
{
|
||||||
|
ID: "01G204214Y9TNJEBX39C7G88SW",
|
||||||
|
Domain: "bad.apples",
|
||||||
|
CreatedByAccountID: suite.testAccounts["admin_account"].ID,
|
||||||
|
CreatedByAccount: suite.testAccounts["admin_account"],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "01HKPSVQ864FQ2JJ01CDGPHHMJ",
|
||||||
|
Domain: "some.bad.apples",
|
||||||
|
CreatedByAccountID: suite.testAccounts["admin_account"].ID,
|
||||||
|
CreatedByAccount: suite.testAccounts["admin_account"],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, block := range blocks {
|
||||||
|
if err := suite.db.CreateDomainBlock(ctx, block); err != nil {
|
||||||
|
suite.FailNow(err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure each block created
|
||||||
|
// above is now present in the db.
|
||||||
|
dbBlocks, err := suite.db.GetDomainBlocks(ctx)
|
||||||
|
if err != nil {
|
||||||
|
suite.FailNow(err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, block := range blocks {
|
||||||
|
if !slices.ContainsFunc(
|
||||||
|
dbBlocks,
|
||||||
|
func(dbBlock *gtsmodel.DomainBlock) bool {
|
||||||
|
return block.Domain == dbBlock.Domain
|
||||||
|
},
|
||||||
|
) {
|
||||||
|
suite.FailNow("", "stored blocks did not contain %s", block.Domain)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// All domains and subdomains
|
||||||
|
// should now be blocked, even
|
||||||
|
// ones without an explicit block.
|
||||||
|
for _, domain := range []string{
|
||||||
|
"bad.apples",
|
||||||
|
"some.bad.apples",
|
||||||
|
"other.bad.apples",
|
||||||
|
} {
|
||||||
|
blocked, err := suite.db.IsDomainBlocked(ctx, domain)
|
||||||
|
if err != nil {
|
||||||
|
suite.FailNow(err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
if !blocked {
|
||||||
|
suite.Fail("", "domain %s should be blocked", domain)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestDomainTestSuite(t *testing.T) {
|
func TestDomainTestSuite(t *testing.T) {
|
||||||
suite.Run(t, new(DomainTestSuite))
|
suite.Run(t, new(DomainTestSuite))
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue