Weight differently

This commit is contained in:
Erik Johnston 2017-05-31 14:29:32 +01:00
parent 535c99f157
commit 293ef29655
2 changed files with 31 additions and 10 deletions

View file

@ -19,6 +19,7 @@ from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
from synapse.api.constants import EventTypes, JoinRules
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.types import get_domain_from_id, get_localpart_from_id
class UserDirectoryStore(SQLBaseStore):
@ -50,26 +51,39 @@ class UserDirectoryStore(SQLBaseStore):
sql = """
INSERT INTO user_directory
(user_id, room_id, display_name, avatar_url, vector)
VALUES (?,?,?,?,to_tsvector('english', ?))
VALUES (?,?,?,?,
setweight(to_tsvector('english', ?), 'A')
|| to_tsvector('english', ?)
|| to_tsvector('english', COALESCE(?, ''))
)
"""
args = (
(
user_id, room_id, p.display_name, p.avatar_url,
get_localpart_from_id(user_id), get_domain_from_id(user_id),
p.display_name,
)
for user_id, p in users_with_profile.iteritems()
)
elif isinstance(self.database_engine, Sqlite3Engine):
sql = """
INSERT INTO user_directory
(user_id, room_id, display_name, avatar_url, value)
VALUES (?,?,?,?,?)
"""
else:
# This should be unreachable.
raise Exception("Unrecognized database engine")
def _add_profiles_to_user_dir_txn(txn):
txn.executemany(sql, (
args = (
(
user_id, room_id, p.display_name, p.avatar_url,
"%s %s" % (user_id, p.display_name,) if p.display_name else user_id
)
for user_id, p in users_with_profile.iteritems()
))
)
else:
# This should be unreachable.
raise Exception("Unrecognized database engine")
def _add_profiles_to_user_dir_txn(txn):
txn.executemany(sql, args)
for user_id in users_with_profile:
txn.call_after(
self.get_user_in_directory.invalidate, (user_id,)
@ -160,8 +174,8 @@ class UserDirectoryStore(SQLBaseStore):
sql = """
SELECT user_id, display_name, avatar_url
FROM user_directory
WHERE vector @@ to_tsquery('english', ?)
ORDER BY ts_rank_cd(vector, to_tsquery('english', ?)) DESC
WHERE vector @@ plainto_tsquery('english', ?)
ORDER BY ts_rank_cd(vector, plainto_tsquery('english', ?)) DESC
LIMIT ?
"""
args = (search_term, search_term, limit + 1,)

View file

@ -62,6 +62,13 @@ def get_domain_from_id(string):
return string[idx + 1:]
def get_localpart_from_id(string):
idx = string.find(":")
if idx == -1:
raise SynapseError(400, "Invalid ID: %r" % (string,))
return string[1:idx]
class DomainSpecificString(
namedtuple("DomainSpecificString", ("localpart", "domain"))
):