Merge remote-tracking branch 'origin/develop' into matrix-org-hotfixes

This commit is contained in:
Erik Johnston 2024-04-18 13:34:28 +01:00
commit 31ac8b745c
39 changed files with 646 additions and 206 deletions

View file

@ -81,7 +81,7 @@ jobs:
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- name: Install Rust - name: Install Rust
uses: dtolnay/rust-toolchain@1.65.0 uses: dtolnay/rust-toolchain@1.66.0
- uses: Swatinem/rust-cache@v2 - uses: Swatinem/rust-cache@v2
- uses: matrix-org/setup-python-poetry@v1 - uses: matrix-org/setup-python-poetry@v1
with: with:
@ -148,7 +148,7 @@ jobs:
uses: actions/checkout@v4 uses: actions/checkout@v4
- name: Install Rust - name: Install Rust
uses: dtolnay/rust-toolchain@1.65.0 uses: dtolnay/rust-toolchain@1.66.0
- uses: Swatinem/rust-cache@v2 - uses: Swatinem/rust-cache@v2
- name: Setup Poetry - name: Setup Poetry
@ -208,7 +208,7 @@ jobs:
with: with:
ref: ${{ github.event.pull_request.head.sha }} ref: ${{ github.event.pull_request.head.sha }}
- name: Install Rust - name: Install Rust
uses: dtolnay/rust-toolchain@1.65.0 uses: dtolnay/rust-toolchain@1.66.0
- uses: Swatinem/rust-cache@v2 - uses: Swatinem/rust-cache@v2
- uses: matrix-org/setup-python-poetry@v1 - uses: matrix-org/setup-python-poetry@v1
with: with:
@ -225,7 +225,7 @@ jobs:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- name: Install Rust - name: Install Rust
uses: dtolnay/rust-toolchain@1.65.0 uses: dtolnay/rust-toolchain@1.66.0
with: with:
components: clippy components: clippy
- uses: Swatinem/rust-cache@v2 - uses: Swatinem/rust-cache@v2
@ -344,7 +344,7 @@ jobs:
postgres:${{ matrix.job.postgres-version }} postgres:${{ matrix.job.postgres-version }}
- name: Install Rust - name: Install Rust
uses: dtolnay/rust-toolchain@1.65.0 uses: dtolnay/rust-toolchain@1.66.0
- uses: Swatinem/rust-cache@v2 - uses: Swatinem/rust-cache@v2
- uses: matrix-org/setup-python-poetry@v1 - uses: matrix-org/setup-python-poetry@v1
@ -386,7 +386,7 @@ jobs:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- name: Install Rust - name: Install Rust
uses: dtolnay/rust-toolchain@1.65.0 uses: dtolnay/rust-toolchain@1.66.0
- uses: Swatinem/rust-cache@v2 - uses: Swatinem/rust-cache@v2
# There aren't wheels for some of the older deps, so we need to install # There aren't wheels for some of the older deps, so we need to install
@ -498,7 +498,7 @@ jobs:
run: cat sytest-blacklist .ci/worker-blacklist > synapse-blacklist-with-workers run: cat sytest-blacklist .ci/worker-blacklist > synapse-blacklist-with-workers
- name: Install Rust - name: Install Rust
uses: dtolnay/rust-toolchain@1.65.0 uses: dtolnay/rust-toolchain@1.66.0
- uses: Swatinem/rust-cache@v2 - uses: Swatinem/rust-cache@v2
- name: Run SyTest - name: Run SyTest
@ -642,7 +642,7 @@ jobs:
path: synapse path: synapse
- name: Install Rust - name: Install Rust
uses: dtolnay/rust-toolchain@1.65.0 uses: dtolnay/rust-toolchain@1.66.0
- uses: Swatinem/rust-cache@v2 - uses: Swatinem/rust-cache@v2
- name: Prepare Complement's Prerequisites - name: Prepare Complement's Prerequisites
@ -674,7 +674,7 @@ jobs:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- name: Install Rust - name: Install Rust
uses: dtolnay/rust-toolchain@1.65.0 uses: dtolnay/rust-toolchain@1.66.0
- uses: Swatinem/rust-cache@v2 - uses: Swatinem/rust-cache@v2
- run: cargo test - run: cargo test

View file

@ -1,3 +1,10 @@
# Synapse 1.105.0 (2024-04-16)
No significant changes since 1.105.0rc1.
# Synapse 1.105.0rc1 (2024-04-11) # Synapse 1.105.0rc1 (2024-04-11)
### Features ### Features

92
Cargo.lock generated
View file

@ -29,6 +29,12 @@ version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa"
[[package]]
name = "base64"
version = "0.21.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567"
[[package]] [[package]]
name = "bitflags" name = "bitflags"
version = "1.3.2" version = "1.3.2"
@ -53,12 +59,27 @@ dependencies = [
"generic-array", "generic-array",
] ]
[[package]]
name = "bytes"
version = "1.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "514de17de45fdb8dc022b1a7975556c53c86f9f0aa5f534b98977b171857c2c9"
[[package]] [[package]]
name = "cfg-if" name = "cfg-if"
version = "1.0.0" version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]]
name = "cpufeatures"
version = "0.2.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "53fe5e26ff1b7aef8bca9c6080520cfb8d9333c7568e1829cef191a9723e5504"
dependencies = [
"libc",
]
[[package]] [[package]]
name = "crypto-common" name = "crypto-common"
version = "0.1.6" version = "0.1.6"
@ -80,6 +101,12 @@ dependencies = [
"subtle", "subtle",
] ]
[[package]]
name = "fnv"
version = "1.0.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1"
[[package]] [[package]]
name = "generic-array" name = "generic-array"
version = "0.14.6" version = "0.14.6"
@ -90,6 +117,30 @@ dependencies = [
"version_check", "version_check",
] ]
[[package]]
name = "headers"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "322106e6bd0cba2d5ead589ddb8150a13d7c4217cf80d7c4f682ca994ccc6aa9"
dependencies = [
"base64",
"bytes",
"headers-core",
"http",
"httpdate",
"mime",
"sha1",
]
[[package]]
name = "headers-core"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "54b4a22553d4242c49fddb9ba998a99962b5cc6f22cb5a3482bec22522403ce4"
dependencies = [
"http",
]
[[package]] [[package]]
name = "heck" name = "heck"
version = "0.4.1" version = "0.4.1"
@ -102,6 +153,23 @@ version = "0.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70"
[[package]]
name = "http"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "21b9ddb458710bc376481b842f5da65cdf31522de232c1ca8146abce2a358258"
dependencies = [
"bytes",
"fnv",
"itoa",
]
[[package]]
name = "httpdate"
version = "1.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9"
[[package]] [[package]]
name = "indoc" name = "indoc"
version = "2.0.4" version = "2.0.4"
@ -122,9 +190,9 @@ checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"
[[package]] [[package]]
name = "libc" name = "libc"
version = "0.2.135" version = "0.2.153"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "68783febc7782c6c5cb401fbda4de5a9898be1762314da0bb2c10ced61f18b0c" checksum = "9c198f91728a82281a64e1f4f9eeb25d82cb32a5de251c6bd1b5154d63a8e7bd"
[[package]] [[package]]
name = "lock_api" name = "lock_api"
@ -157,6 +225,12 @@ dependencies = [
"autocfg", "autocfg",
] ]
[[package]]
name = "mime"
version = "0.3.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a"
[[package]] [[package]]
name = "once_cell" name = "once_cell"
version = "1.15.0" version = "1.15.0"
@ -376,6 +450,17 @@ dependencies = [
"serde", "serde",
] ]
[[package]]
name = "sha1"
version = "0.10.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f04293dc80c3993519f2d7f6f511707ee7094fe0c6d3406feb330cdb3540eba3"
dependencies = [
"cfg-if",
"cpufeatures",
"digest",
]
[[package]] [[package]]
name = "smallvec" name = "smallvec"
version = "1.10.0" version = "1.10.0"
@ -405,7 +490,10 @@ version = "0.1.0"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"blake2", "blake2",
"bytes",
"headers",
"hex", "hex",
"http",
"lazy_static", "lazy_static",
"log", "log",
"pyo3", "pyo3",

1
changelog.d/16920.bugfix Normal file
View file

@ -0,0 +1 @@
Adds validation to ensure that the `limit` parameter on `/publicRooms` is non-negative.

1
changelog.d/16943.bugfix Normal file
View file

@ -0,0 +1 @@
Make the CSAPI endpoint `/keys/device_signing/upload` idempotent.

1
changelog.d/17032.misc Normal file
View file

@ -0,0 +1 @@
Use new receipts column to optimise receipt and push action SQL queries. Contributed by Nick @ Beeper (@fizzadar).

1
changelog.d/17036.misc Normal file
View file

@ -0,0 +1 @@
Fix mypy with latest Twisted release.

1
changelog.d/17079.misc Normal file
View file

@ -0,0 +1 @@
Bump minimum supported Rust version to 1.66.0.

1
changelog.d/17081.misc Normal file
View file

@ -0,0 +1 @@
Add helpers to transform Twisted requests to Rust http Requests/Responses.

View file

@ -0,0 +1 @@
Support delegating the rendezvous mechanism described MSC4108 to an external implementation.

1
changelog.d/17096.misc Normal file
View file

@ -0,0 +1 @@
Use new receipts column to optimise receipt and push action SQL queries. Contributed by Nick @ Beeper (@fizzadar).

6
debian/changelog vendored
View file

@ -1,3 +1,9 @@
matrix-synapse-py3 (1.105.0) stable; urgency=medium
* New Synapse release 1.105.0.
-- Synapse Packaging team <packages@matrix.org> Tue, 16 Apr 2024 15:53:23 +0100
matrix-synapse-py3 (1.105.0~rc1) stable; urgency=medium matrix-synapse-py3 (1.105.0~rc1) stable; urgency=medium
* New Synapse release 1.105.0rc1. * New Synapse release 1.105.0rc1.

View file

@ -102,6 +102,8 @@ experimental_features:
msc3391_enabled: true msc3391_enabled: true
# Filtering /messages by relation type. # Filtering /messages by relation type.
msc3874_enabled: true msc3874_enabled: true
# no UIA for x-signing upload for the first time
msc3967_enabled: true
server_notices: server_notices:
system_mxid_localpart: _server system_mxid_localpart: _server

View file

@ -96,7 +96,7 @@ module-name = "synapse.synapse_rust"
[tool.poetry] [tool.poetry]
name = "matrix-synapse" name = "matrix-synapse"
version = "1.105.0rc1" version = "1.105.0"
description = "Homeserver for the Matrix decentralised comms protocol" description = "Homeserver for the Matrix decentralised comms protocol"
authors = ["Matrix.org Team and Contributors <packages@matrix.org>"] authors = ["Matrix.org Team and Contributors <packages@matrix.org>"]
license = "AGPL-3.0-or-later" license = "AGPL-3.0-or-later"

View file

@ -7,7 +7,7 @@ name = "synapse"
version = "0.1.0" version = "0.1.0"
edition = "2021" edition = "2021"
rust-version = "1.65.0" rust-version = "1.66.0"
[lib] [lib]
name = "synapse" name = "synapse"
@ -23,6 +23,9 @@ name = "synapse.synapse_rust"
[dependencies] [dependencies]
anyhow = "1.0.63" anyhow = "1.0.63"
bytes = "1.6.0"
headers = "0.4.0"
http = "1.1.0"
lazy_static = "1.4.0" lazy_static = "1.4.0"
log = "0.4.17" log = "0.4.17"
pyo3 = { version = "0.20.0", features = [ pyo3 = { version = "0.20.0", features = [

60
rust/src/errors.rs Normal file
View file

@ -0,0 +1,60 @@
/*
* This file is licensed under the Affero General Public License (AGPL) version 3.
*
* Copyright (C) 2024 New Vector, Ltd
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as
* published by the Free Software Foundation, either version 3 of the
* License, or (at your option) any later version.
*
* See the GNU Affero General Public License for more details:
* <https://www.gnu.org/licenses/agpl-3.0.html>.
*/
#![allow(clippy::new_ret_no_self)]
use std::collections::HashMap;
use http::{HeaderMap, StatusCode};
use pyo3::{exceptions::PyValueError, import_exception};
import_exception!(synapse.api.errors, SynapseError);
impl SynapseError {
pub fn new(
code: StatusCode,
message: String,
errcode: &'static str,
additional_fields: Option<HashMap<String, String>>,
headers: Option<HeaderMap>,
) -> pyo3::PyErr {
// Transform the HeaderMap into a HashMap<String, String>
let headers = if let Some(headers) = headers {
let mut map = HashMap::with_capacity(headers.len());
for (key, value) in headers.iter() {
let Ok(value) = value.to_str() else {
// This should never happen, but we don't want to panic in case it does
return PyValueError::new_err(
"Could not construct SynapseError: header value is not valid ASCII",
);
};
map.insert(key.as_str().to_owned(), value.to_owned());
}
Some(map)
} else {
None
};
SynapseError::new_err((code.as_u16(), message, errcode, additional_fields, headers))
}
}
import_exception!(synapse.api.errors, NotFoundError);
impl NotFoundError {
pub fn new() -> pyo3::PyErr {
NotFoundError::new_err(())
}
}

165
rust/src/http.rs Normal file
View file

@ -0,0 +1,165 @@
/*
* This file is licensed under the Affero General Public License (AGPL) version 3.
*
* Copyright (C) 2024 New Vector, Ltd
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as
* published by the Free Software Foundation, either version 3 of the
* License, or (at your option) any later version.
*
* See the GNU Affero General Public License for more details:
* <https://www.gnu.org/licenses/agpl-3.0.html>.
*/
use bytes::{Buf, BufMut, Bytes, BytesMut};
use headers::{Header, HeaderMapExt};
use http::{HeaderName, HeaderValue, Method, Request, Response, StatusCode, Uri};
use pyo3::{
exceptions::PyValueError,
types::{PyBytes, PySequence, PyTuple},
PyAny, PyResult,
};
use crate::errors::SynapseError;
/// Read a file-like Python object by chunks
///
/// # Errors
///
/// Returns an error if calling the `read` on the Python object failed
fn read_io_body(body: &PyAny, chunk_size: usize) -> PyResult<Bytes> {
let mut buf = BytesMut::new();
loop {
let bytes: &PyBytes = body.call_method1("read", (chunk_size,))?.downcast()?;
if bytes.as_bytes().is_empty() {
return Ok(buf.into());
}
buf.put(bytes.as_bytes());
}
}
/// Transform a Twisted `IRequest` to an [`http::Request`]
///
/// It uses the following members of `IRequest`:
/// - `content`, which is expected to be a file-like object with a `read` method
/// - `uri`, which is expected to be a valid URI as `bytes`
/// - `method`, which is expected to be a valid HTTP method as `bytes`
/// - `requestHeaders`, which is expected to have a `getAllRawHeaders` method
///
/// # Errors
///
/// Returns an error if the Python object doesn't properly implement `IRequest`
pub fn http_request_from_twisted(request: &PyAny) -> PyResult<Request<Bytes>> {
let content = request.getattr("content")?;
let body = read_io_body(content, 4096)?;
let mut req = Request::new(body);
let uri: &PyBytes = request.getattr("uri")?.downcast()?;
*req.uri_mut() =
Uri::try_from(uri.as_bytes()).map_err(|_| PyValueError::new_err("invalid uri"))?;
let method: &PyBytes = request.getattr("method")?.downcast()?;
*req.method_mut() = Method::from_bytes(method.as_bytes())
.map_err(|_| PyValueError::new_err("invalid method"))?;
let headers_iter = request
.getattr("requestHeaders")?
.call_method0("getAllRawHeaders")?
.iter()?;
for header in headers_iter {
let header = header?;
let header: &PyTuple = header.downcast()?;
let name: &PyBytes = header.get_item(0)?.downcast()?;
let name = HeaderName::from_bytes(name.as_bytes())
.map_err(|_| PyValueError::new_err("invalid header name"))?;
let values: &PySequence = header.get_item(1)?.downcast()?;
for index in 0..values.len()? {
let value: &PyBytes = values.get_item(index)?.downcast()?;
let value = HeaderValue::from_bytes(value.as_bytes())
.map_err(|_| PyValueError::new_err("invalid header value"))?;
req.headers_mut().append(name.clone(), value);
}
}
Ok(req)
}
/// Send an [`http::Response`] through a Twisted `IRequest`
///
/// It uses the following members of `IRequest`:
///
/// - `responseHeaders`, which is expected to have a `addRawHeader(bytes, bytes)` method
/// - `setResponseCode(int)` method
/// - `write(bytes)` method
/// - `finish()` method
///
/// # Errors
///
/// Returns an error if the Python object doesn't properly implement `IRequest`
pub fn http_response_to_twisted<B>(request: &PyAny, response: Response<B>) -> PyResult<()>
where
B: Buf,
{
let (parts, mut body) = response.into_parts();
request.call_method1("setResponseCode", (parts.status.as_u16(),))?;
let response_headers = request.getattr("responseHeaders")?;
for (name, value) in parts.headers.iter() {
response_headers.call_method1("addRawHeader", (name.as_str(), value.as_bytes()))?;
}
while body.remaining() != 0 {
let chunk = body.chunk();
request.call_method1("write", (chunk,))?;
body.advance(chunk.len());
}
request.call_method0("finish")?;
Ok(())
}
/// An extension trait for [`HeaderMap`] that provides typed access to headers, and throws the
/// right python exceptions when the header is missing or fails to parse.
///
/// [`HeaderMap`]: headers::HeaderMap
pub trait HeaderMapPyExt: HeaderMapExt {
/// Get a header from the map, returning an error if it is missing or invalid.
fn typed_get_required<H>(&self) -> PyResult<H>
where
H: Header,
{
self.typed_get_optional::<H>()?.ok_or_else(|| {
SynapseError::new(
StatusCode::BAD_REQUEST,
format!("Missing required header: {}", H::name()),
"M_MISSING_PARAM",
None,
None,
)
})
}
/// Get a header from the map, returning `None` if it is missing and an error if it is invalid.
fn typed_get_optional<H>(&self) -> PyResult<Option<H>>
where
H: Header,
{
self.typed_try_get::<H>().map_err(|_| {
SynapseError::new(
StatusCode::BAD_REQUEST,
format!("Invalid header: {}", H::name()),
"M_INVALID_PARAM",
None,
None,
)
})
}
}
impl<T: HeaderMapExt> HeaderMapPyExt for T {}

View file

@ -3,7 +3,9 @@ use pyo3::prelude::*;
use pyo3_log::ResetHandle; use pyo3_log::ResetHandle;
pub mod acl; pub mod acl;
pub mod errors;
pub mod events; pub mod events;
pub mod http;
pub mod push; pub mod push;
lazy_static! { lazy_static! {

View file

@ -214,7 +214,7 @@ fi
extra_test_args=() extra_test_args=()
test_packages="./tests/csapi ./tests ./tests/msc3874 ./tests/msc3890 ./tests/msc3391 ./tests/msc3930 ./tests/msc3902" test_packages="./tests/csapi ./tests ./tests/msc3874 ./tests/msc3890 ./tests/msc3391 ./tests/msc3930 ./tests/msc3902 ./tests/msc3967"
# Enable dirty runs, so tests will reuse the same container where possible. # Enable dirty runs, so tests will reuse the same container where possible.
# This significantly speeds up tests, but increases the possibility of test pollution. # This significantly speeds up tests, but increases the possibility of test pollution.

View file

@ -411,3 +411,14 @@ class ExperimentalConfig(Config):
self.msc4069_profile_inhibit_propagation = experimental.get( self.msc4069_profile_inhibit_propagation = experimental.get(
"msc4069_profile_inhibit_propagation", False "msc4069_profile_inhibit_propagation", False
) )
# MSC4108: Mechanism to allow OIDC sign in and E2EE set up via QR code
self.msc4108_delegation_endpoint: Optional[str] = experimental.get(
"msc4108_delegation_endpoint", None
)
if self.msc4108_delegation_endpoint is not None and not self.msc3861.enabled:
raise ConfigError(
"MSC4108 requires MSC3861 to be enabled",
("experimental", "msc4108_delegation_endpoint"),
)

View file

@ -1476,6 +1476,42 @@ class E2eKeysHandler:
else: else:
return exists, self.clock.time_msec() < ts_replacable_without_uia_before return exists, self.clock.time_msec() < ts_replacable_without_uia_before
async def has_different_keys(self, user_id: str, body: JsonDict) -> bool:
"""
Check if a key provided in `body` differs from the same key stored in the DB. Returns
true on the first difference. If a key exists in `body` but does not exist in the DB,
returns True. If `body` has no keys, this always returns False.
Note by 'key' we mean Matrix key rather than JSON key.
The purpose of this function is to detect whether or not we need to apply UIA checks.
We must apply UIA checks if any key in the database is being overwritten. If a key is
being inserted for the first time, or if the key exactly matches what is in the database,
then no UIA check needs to be performed.
Args:
user_id: The user who sent the `body`.
body: The JSON request body from POST /keys/device_signing/upload
Returns:
True if any key in `body` has a different value in the database.
"""
# Ensure that each key provided in the request body exactly matches the one we have stored.
# The first time we see the DB having a different key to the matching request key, bail.
# Note: we do not care if the DB has a key which the request does not specify, as we only
# care about *replacements* or *insertions* (i.e UPSERT)
req_body_key_to_db_key = {
"master_key": "master",
"self_signing_key": "self_signing",
"user_signing_key": "user_signing",
}
for req_body_key, db_key in req_body_key_to_db_key.items():
if req_body_key in body:
existing_key = await self.store.get_e2e_cross_signing_key(
user_id, db_key
)
if existing_key != body[req_body_key]:
return True
return False
def _check_cross_signing_key( def _check_cross_signing_key(
key: JsonDict, user_id: str, key_type: str, signing_key: Optional[VerifyKey] = None key: JsonDict, user_id: str, key_type: str, signing_key: Optional[VerifyKey] = None

View file

@ -262,7 +262,8 @@ class _ProxyResponseBody(protocol.Protocol):
self._request.finish() self._request.finish()
else: else:
# Abort the underlying request since our remote request also failed. # Abort the underlying request since our remote request also failed.
self._request.transport.abortConnection() if self._request.channel:
self._request.channel.forceAbortClient()
class ProxySite(Site): class ProxySite(Site):

View file

@ -153,9 +153,9 @@ def return_json_error(
# Only respond with an error response if we haven't already started writing, # Only respond with an error response if we haven't already started writing,
# otherwise lets just kill the connection # otherwise lets just kill the connection
if request.startedWriting: if request.startedWriting:
if request.transport: if request.channel:
try: try:
request.transport.abortConnection() request.channel.forceAbortClient()
except Exception: except Exception:
# abortConnection throws if the connection is already closed # abortConnection throws if the connection is already closed
pass pass
@ -909,7 +909,18 @@ def set_cors_headers(request: "SynapseRequest") -> None:
request.setHeader( request.setHeader(
b"Access-Control-Allow-Methods", b"GET, HEAD, POST, PUT, DELETE, OPTIONS" b"Access-Control-Allow-Methods", b"GET, HEAD, POST, PUT, DELETE, OPTIONS"
) )
if request.experimental_cors_msc3886: if request.path is not None and request.path.startswith(
b"/_matrix/client/unstable/org.matrix.msc4108/rendezvous"
):
request.setHeader(
b"Access-Control-Allow-Headers",
b"Content-Type, If-Match, If-None-Match",
)
request.setHeader(
b"Access-Control-Expose-Headers",
b"Synapse-Trace-Id, Server, ETag",
)
elif request.experimental_cors_msc3886:
request.setHeader( request.setHeader(
b"Access-Control-Allow-Headers", b"Access-Control-Allow-Headers",
b"X-Requested-With, Content-Type, Authorization, Date, If-Match, If-None-Match", b"X-Requested-With, Content-Type, Authorization, Date, If-Match, If-None-Match",

View file

@ -19,7 +19,8 @@
# #
# #
""" This module contains base REST classes for constructing REST servlets. """ """This module contains base REST classes for constructing REST servlets."""
import enum import enum
import logging import logging
from http import HTTPStatus from http import HTTPStatus
@ -65,17 +66,49 @@ def parse_integer(request: Request, name: str, default: int) -> int: ...
@overload @overload
def parse_integer(request: Request, name: str, *, required: Literal[True]) -> int: ... def parse_integer(
request: Request, name: str, *, default: int, negative: bool
) -> int: ...
@overload @overload
def parse_integer( def parse_integer(
request: Request, name: str, default: Optional[int] = None, required: bool = False request: Request, name: str, *, default: int, negative: bool = False
) -> int: ...
@overload
def parse_integer(
request: Request, name: str, *, required: Literal[True], negative: bool = False
) -> int: ...
@overload
def parse_integer(
request: Request, name: str, *, default: Literal[None], negative: bool = False
) -> None: ...
@overload
def parse_integer(request: Request, name: str, *, negative: bool) -> Optional[int]: ...
@overload
def parse_integer(
request: Request,
name: str,
default: Optional[int] = None,
required: bool = False,
negative: bool = False,
) -> Optional[int]: ... ) -> Optional[int]: ...
def parse_integer( def parse_integer(
request: Request, name: str, default: Optional[int] = None, required: bool = False request: Request,
name: str,
default: Optional[int] = None,
required: bool = False,
negative: bool = False,
) -> Optional[int]: ) -> Optional[int]:
"""Parse an integer parameter from the request string """Parse an integer parameter from the request string
@ -85,16 +118,17 @@ def parse_integer(
default: value to use if the parameter is absent, defaults to None. default: value to use if the parameter is absent, defaults to None.
required: whether to raise a 400 SynapseError if the parameter is absent, required: whether to raise a 400 SynapseError if the parameter is absent,
defaults to False. defaults to False.
negative: whether to allow negative integers, defaults to True.
Returns: Returns:
An int value or the default. An int value or the default.
Raises: Raises:
SynapseError: if the parameter is absent and required, or if the SynapseError: if the parameter is absent and required, if the
parameter is present and not an integer. parameter is present and not an integer, or if the
parameter is illegitimate negative.
""" """
args: Mapping[bytes, Sequence[bytes]] = request.args # type: ignore args: Mapping[bytes, Sequence[bytes]] = request.args # type: ignore
return parse_integer_from_args(args, name, default, required) return parse_integer_from_args(args, name, default, required, negative)
@overload @overload
@ -120,6 +154,7 @@ def parse_integer_from_args(
name: str, name: str,
default: Optional[int] = None, default: Optional[int] = None,
required: bool = False, required: bool = False,
negative: bool = False,
) -> Optional[int]: ... ) -> Optional[int]: ...
@ -128,6 +163,7 @@ def parse_integer_from_args(
name: str, name: str,
default: Optional[int] = None, default: Optional[int] = None,
required: bool = False, required: bool = False,
negative: bool = True,
) -> Optional[int]: ) -> Optional[int]:
"""Parse an integer parameter from the request string """Parse an integer parameter from the request string
@ -137,33 +173,37 @@ def parse_integer_from_args(
default: value to use if the parameter is absent, defaults to None. default: value to use if the parameter is absent, defaults to None.
required: whether to raise a 400 SynapseError if the parameter is absent, required: whether to raise a 400 SynapseError if the parameter is absent,
defaults to False. defaults to False.
negative: whether to allow negative integers, defaults to True.
Returns: Returns:
An int value or the default. An int value or the default.
Raises: Raises:
SynapseError: if the parameter is absent and required, or if the SynapseError: if the parameter is absent and required, if the
parameter is present and not an integer. parameter is present and not an integer, or if the
parameter is illegitimate negative.
""" """
name_bytes = name.encode("ascii") name_bytes = name.encode("ascii")
if name_bytes in args: if name_bytes not in args:
try: if not required:
return int(args[name_bytes][0])
except Exception:
message = "Query parameter %r must be an integer" % (name,)
raise SynapseError(
HTTPStatus.BAD_REQUEST, message, errcode=Codes.INVALID_PARAM
)
else:
if required:
message = "Missing integer query parameter %r" % (name,)
raise SynapseError(
HTTPStatus.BAD_REQUEST, message, errcode=Codes.MISSING_PARAM
)
else:
return default return default
message = f"Missing required integer query parameter {name}"
raise SynapseError(HTTPStatus.BAD_REQUEST, message, errcode=Codes.MISSING_PARAM)
try:
integer = int(args[name_bytes][0])
except Exception:
message = f"Query parameter {name} must be an integer"
raise SynapseError(HTTPStatus.BAD_REQUEST, message, errcode=Codes.INVALID_PARAM)
if not negative and integer < 0:
message = f"Query parameter {name} must be a positive integer."
raise SynapseError(HTTPStatus.BAD_REQUEST, message, errcode=Codes.INVALID_PARAM)
return integer
@overload @overload
def parse_boolean(request: Request, name: str, default: bool) -> bool: ... def parse_boolean(request: Request, name: str, default: bool) -> bool: ...

View file

@ -150,7 +150,8 @@ class SynapseRequest(Request):
self.get_method(), self.get_method(),
self.get_redacted_uri(), self.get_redacted_uri(),
) )
self.transport.abortConnection() if self.channel:
self.channel.forceAbortClient()
return return
super().handleContentChunk(data) super().handleContentChunk(data)

View file

@ -23,7 +23,7 @@ from http import HTTPStatus
from typing import TYPE_CHECKING, Tuple from typing import TYPE_CHECKING, Tuple
from synapse.api.constants import Direction from synapse.api.constants import Direction
from synapse.api.errors import Codes, NotFoundError, SynapseError from synapse.api.errors import NotFoundError, SynapseError
from synapse.federation.transport.server import Authenticator from synapse.federation.transport.server import Authenticator
from synapse.http.servlet import RestServlet, parse_enum, parse_integer, parse_string from synapse.http.servlet import RestServlet, parse_enum, parse_integer, parse_string
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
@ -61,22 +61,8 @@ class ListDestinationsRestServlet(RestServlet):
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self._auth, request) await assert_requester_is_admin(self._auth, request)
start = parse_integer(request, "from", default=0) start = parse_integer(request, "from", default=0, negative=False)
limit = parse_integer(request, "limit", default=100) limit = parse_integer(request, "limit", default=100, negative=False)
if start < 0:
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"Query parameter from must be a string representing a positive integer.",
errcode=Codes.INVALID_PARAM,
)
if limit < 0:
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"Query parameter limit must be a string representing a positive integer.",
errcode=Codes.INVALID_PARAM,
)
destination = parse_string(request, "destination") destination = parse_string(request, "destination")
@ -195,22 +181,8 @@ class DestinationMembershipRestServlet(RestServlet):
if not await self._store.is_destination_known(destination): if not await self._store.is_destination_known(destination):
raise NotFoundError("Unknown destination") raise NotFoundError("Unknown destination")
start = parse_integer(request, "from", default=0) start = parse_integer(request, "from", default=0, negative=False)
limit = parse_integer(request, "limit", default=100) limit = parse_integer(request, "limit", default=100, negative=False)
if start < 0:
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"Query parameter from must be a string representing a positive integer.",
errcode=Codes.INVALID_PARAM,
)
if limit < 0:
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"Query parameter limit must be a string representing a positive integer.",
errcode=Codes.INVALID_PARAM,
)
direction = parse_enum(request, "dir", Direction, default=Direction.FORWARDS) direction = parse_enum(request, "dir", Direction, default=Direction.FORWARDS)

View file

@ -311,29 +311,17 @@ class DeleteMediaByDateSize(RestServlet):
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request) await assert_requester_is_admin(self.auth, request)
before_ts = parse_integer(request, "before_ts", required=True) before_ts = parse_integer(request, "before_ts", required=True, negative=False)
size_gt = parse_integer(request, "size_gt", default=0) size_gt = parse_integer(request, "size_gt", default=0, negative=False)
keep_profiles = parse_boolean(request, "keep_profiles", default=True) keep_profiles = parse_boolean(request, "keep_profiles", default=True)
if before_ts < 0: if before_ts < 30000000000: # Dec 1970 in milliseconds, Aug 2920 in seconds
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"Query parameter before_ts must be a positive integer.",
errcode=Codes.INVALID_PARAM,
)
elif before_ts < 30000000000: # Dec 1970 in milliseconds, Aug 2920 in seconds
raise SynapseError( raise SynapseError(
HTTPStatus.BAD_REQUEST, HTTPStatus.BAD_REQUEST,
"Query parameter before_ts you provided is from the year 1970. " "Query parameter before_ts you provided is from the year 1970. "
+ "Double check that you are providing a timestamp in milliseconds.", + "Double check that you are providing a timestamp in milliseconds.",
errcode=Codes.INVALID_PARAM, errcode=Codes.INVALID_PARAM,
) )
if size_gt < 0:
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"Query parameter size_gt must be a string representing a positive integer.",
errcode=Codes.INVALID_PARAM,
)
# This check is useless, we keep it for the legacy endpoint only. # This check is useless, we keep it for the legacy endpoint only.
if server_name is not None and self.server_name != server_name: if server_name is not None and self.server_name != server_name:
@ -389,22 +377,8 @@ class UserMediaRestServlet(RestServlet):
if user is None: if user is None:
raise NotFoundError("Unknown user") raise NotFoundError("Unknown user")
start = parse_integer(request, "from", default=0) start = parse_integer(request, "from", default=0, negative=False)
limit = parse_integer(request, "limit", default=100) limit = parse_integer(request, "limit", default=100, negative=False)
if start < 0:
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"Query parameter from must be a string representing a positive integer.",
errcode=Codes.INVALID_PARAM,
)
if limit < 0:
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"Query parameter limit must be a string representing a positive integer.",
errcode=Codes.INVALID_PARAM,
)
# If neither `order_by` nor `dir` is set, set the default order # If neither `order_by` nor `dir` is set, set the default order
# to newest media is on top for backward compatibility. # to newest media is on top for backward compatibility.
@ -447,22 +421,8 @@ class UserMediaRestServlet(RestServlet):
if user is None: if user is None:
raise NotFoundError("Unknown user") raise NotFoundError("Unknown user")
start = parse_integer(request, "from", default=0) start = parse_integer(request, "from", default=0, negative=False)
limit = parse_integer(request, "limit", default=100) limit = parse_integer(request, "limit", default=100, negative=False)
if start < 0:
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"Query parameter from must be a string representing a positive integer.",
errcode=Codes.INVALID_PARAM,
)
if limit < 0:
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"Query parameter limit must be a string representing a positive integer.",
errcode=Codes.INVALID_PARAM,
)
# If neither `order_by` nor `dir` is set, set the default order # If neither `order_by` nor `dir` is set, set the default order
# to newest media is on top for backward compatibility. # to newest media is on top for backward compatibility.

View file

@ -63,38 +63,12 @@ class UserMediaStatisticsRestServlet(RestServlet):
), ),
) )
start = parse_integer(request, "from", default=0) start = parse_integer(request, "from", default=0, negative=False)
if start < 0: limit = parse_integer(request, "limit", default=100, negative=False)
raise SynapseError( from_ts = parse_integer(request, "from_ts", default=0, negative=False)
HTTPStatus.BAD_REQUEST, until_ts = parse_integer(request, "until_ts", negative=False)
"Query parameter from must be a string representing a positive integer.",
errcode=Codes.INVALID_PARAM,
)
limit = parse_integer(request, "limit", default=100)
if limit < 0:
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"Query parameter limit must be a string representing a positive integer.",
errcode=Codes.INVALID_PARAM,
)
from_ts = parse_integer(request, "from_ts", default=0)
if from_ts < 0:
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"Query parameter from_ts must be a string representing a positive integer.",
errcode=Codes.INVALID_PARAM,
)
until_ts = parse_integer(request, "until_ts")
if until_ts is not None: if until_ts is not None:
if until_ts < 0:
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"Query parameter until_ts must be a string representing a positive integer.",
errcode=Codes.INVALID_PARAM,
)
if until_ts <= from_ts: if until_ts <= from_ts:
raise SynapseError( raise SynapseError(
HTTPStatus.BAD_REQUEST, HTTPStatus.BAD_REQUEST,

View file

@ -90,22 +90,8 @@ class UsersRestServletV2(RestServlet):
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request) await assert_requester_is_admin(self.auth, request)
start = parse_integer(request, "from", default=0) start = parse_integer(request, "from", default=0, negative=False)
limit = parse_integer(request, "limit", default=100) limit = parse_integer(request, "limit", default=100, negative=False)
if start < 0:
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"Query parameter from must be a string representing a positive integer.",
errcode=Codes.INVALID_PARAM,
)
if limit < 0:
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"Query parameter limit must be a string representing a positive integer.",
errcode=Codes.INVALID_PARAM,
)
user_id = parse_string(request, "user_id") user_id = parse_string(request, "user_id")
name = parse_string(request, "name", encoding="utf-8") name = parse_string(request, "name", encoding="utf-8")

View file

@ -409,7 +409,18 @@ class SigningKeyUploadServlet(RestServlet):
# But first-time setup is fine # But first-time setup is fine
elif self.hs.config.experimental.msc3967_enabled: elif self.hs.config.experimental.msc3967_enabled:
# If we already have a master key then cross signing is set up and we require UIA to reset # MSC3967 allows this endpoint to 200 OK for idempotency. Resending exactly the same
# keys should just 200 OK without doing a UIA prompt.
keys_are_different = await self.e2e_keys_handler.has_different_keys(
user_id, body
)
if not keys_are_different:
# FIXME: we do not fallthrough to upload_signing_keys_for_user because confusingly
# if we do, we 500 as it looks like it tries to INSERT the same key twice, causing a
# unique key constraint violation. This sounds like a bug?
return 200, {}
# the keys are different, is x-signing set up? If no, then the keys don't exist which is
# why they are different. If yes, then we need to UIA to change them.
if is_cross_signing_setup: if is_cross_signing_setup:
await self.auth_handler.validate_user_via_ui_auth( await self.auth_handler.validate_user_via_ui_auth(
requester, requester,
@ -420,7 +431,6 @@ class SigningKeyUploadServlet(RestServlet):
can_skip_ui_auth=False, can_skip_ui_auth=False,
) )
# Otherwise we don't require UIA since we are setting up cross signing for first time # Otherwise we don't require UIA since we are setting up cross signing for first time
else: else:
# Previous behaviour is to always require UIA but allow it to be skipped # Previous behaviour is to always require UIA but allow it to be skipped
await self.auth_handler.validate_user_via_ui_auth( await self.auth_handler.validate_user_via_ui_auth(

View file

@ -2,7 +2,7 @@
# This file is licensed under the Affero General Public License (AGPL) version 3. # This file is licensed under the Affero General Public License (AGPL) version 3.
# #
# Copyright 2022 The Matrix.org Foundation C.I.C. # Copyright 2022 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd # Copyright (C) 2023-2024 New Vector, Ltd
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as # it under the terms of the GNU Affero General Public License as
@ -34,7 +34,7 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class RendezvousServlet(RestServlet): class MSC3886RendezvousServlet(RestServlet):
""" """
This is a placeholder implementation of [MSC3886](https://github.com/matrix-org/matrix-spec-proposals/pull/3886) This is a placeholder implementation of [MSC3886](https://github.com/matrix-org/matrix-spec-proposals/pull/3886)
simple client rendezvous capability that is used by the "Sign in with QR" functionality. simple client rendezvous capability that is used by the "Sign in with QR" functionality.
@ -76,6 +76,30 @@ class RendezvousServlet(RestServlet):
# PUT, GET and DELETE are not implemented as they should be fulfilled by the redirect target. # PUT, GET and DELETE are not implemented as they should be fulfilled by the redirect target.
class MSC4108DelegationRendezvousServlet(RestServlet):
PATTERNS = client_patterns(
"/org.matrix.msc4108/rendezvous$", releases=[], v1=False, unstable=True
)
def __init__(self, hs: "HomeServer"):
super().__init__()
redirection_target: Optional[str] = (
hs.config.experimental.msc4108_delegation_endpoint
)
assert (
redirection_target is not None
), "Servlet is only registered if there is a delegation target"
self.endpoint = redirection_target.encode("utf-8")
async def on_POST(self, request: SynapseRequest) -> None:
respond_with_redirect(
request, self.endpoint, statusCode=TEMPORARY_REDIRECT, cors=True
)
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
if hs.config.experimental.msc3886_endpoint is not None: if hs.config.experimental.msc3886_endpoint is not None:
RendezvousServlet(hs).register(http_server) MSC3886RendezvousServlet(hs).register(http_server)
if hs.config.experimental.msc4108_delegation_endpoint is not None:
MSC4108DelegationRendezvousServlet(hs).register(http_server)

View file

@ -499,7 +499,7 @@ class PublicRoomListRestServlet(RestServlet):
if server: if server:
raise e raise e
limit: Optional[int] = parse_integer(request, "limit", 0) limit: Optional[int] = parse_integer(request, "limit", 0, negative=False)
since_token = parse_string(request, "since") since_token = parse_string(request, "since")
if limit == 0: if limit == 0:

View file

@ -140,6 +140,9 @@ class VersionsRestServlet(RestServlet):
"org.matrix.msc4069": self.config.experimental.msc4069_profile_inhibit_propagation, "org.matrix.msc4069": self.config.experimental.msc4069_profile_inhibit_propagation,
# Allows clients to handle push for encrypted events. # Allows clients to handle push for encrypted events.
"org.matrix.msc4028": self.config.experimental.msc4028_push_encrypted_events, "org.matrix.msc4028": self.config.experimental.msc4028_push_encrypted_events,
# MSC4108: Mechanism to allow OIDC sign in and E2EE set up via QR code
"org.matrix.msc4108": self.config.experimental.msc4108_delegation_endpoint
is not None,
}, },
}, },
) )

View file

@ -72,9 +72,6 @@ class PreviewUrlResource(RestServlet):
# XXX: if get_user_by_req fails, what should we do in an async render? # XXX: if get_user_by_req fails, what should we do in an async render?
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
url = parse_string(request, "url", required=True) url = parse_string(request, "url", required=True)
ts = parse_integer(request, "ts") ts = parse_integer(request, "ts", default=self.clock.time_msec())
if ts is None:
ts = self.clock.time_msec()
og = await self.url_previewer.preview(url, requester.user, ts) og = await self.url_previewer.preview(url, requester.user, ts)
respond_with_json_bytes(request, 200, og, send_cors=True) respond_with_json_bytes(request, 200, og, send_cors=True)

View file

@ -385,7 +385,6 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
WITH all_receipts AS ( WITH all_receipts AS (
SELECT room_id, thread_id, MAX(event_stream_ordering) AS max_receipt_stream_ordering SELECT room_id, thread_id, MAX(event_stream_ordering) AS max_receipt_stream_ordering
FROM receipts_linearized FROM receipts_linearized
LEFT JOIN events USING (room_id, event_id)
WHERE WHERE
{receipt_types_clause} {receipt_types_clause}
AND user_id = ? AND user_id = ?
@ -621,13 +620,12 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
SELECT notif_count, COALESCE(unread_count, 0), thread_id SELECT notif_count, COALESCE(unread_count, 0), thread_id
FROM event_push_summary FROM event_push_summary
LEFT JOIN ( LEFT JOIN (
SELECT thread_id, MAX(stream_ordering) AS threaded_receipt_stream_ordering SELECT thread_id, MAX(event_stream_ordering) AS threaded_receipt_stream_ordering
FROM receipts_linearized FROM receipts_linearized
LEFT JOIN events USING (room_id, event_id)
WHERE WHERE
user_id = ? user_id = ?
AND room_id = ? AND room_id = ?
AND stream_ordering > ? AND event_stream_ordering > ?
AND {receipt_types_clause} AND {receipt_types_clause}
GROUP BY thread_id GROUP BY thread_id
) AS receipts USING (thread_id) ) AS receipts USING (thread_id)
@ -659,13 +657,12 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
sql = f""" sql = f"""
SELECT COUNT(*), thread_id FROM event_push_actions SELECT COUNT(*), thread_id FROM event_push_actions
LEFT JOIN ( LEFT JOIN (
SELECT thread_id, MAX(stream_ordering) AS threaded_receipt_stream_ordering SELECT thread_id, MAX(event_stream_ordering) AS threaded_receipt_stream_ordering
FROM receipts_linearized FROM receipts_linearized
LEFT JOIN events USING (room_id, event_id)
WHERE WHERE
user_id = ? user_id = ?
AND room_id = ? AND room_id = ?
AND stream_ordering > ? AND event_stream_ordering > ?
AND {receipt_types_clause} AND {receipt_types_clause}
GROUP BY thread_id GROUP BY thread_id
) AS receipts USING (thread_id) ) AS receipts USING (thread_id)
@ -738,13 +735,12 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
thread_id thread_id
FROM event_push_actions FROM event_push_actions
LEFT JOIN ( LEFT JOIN (
SELECT thread_id, MAX(stream_ordering) AS threaded_receipt_stream_ordering SELECT thread_id, MAX(event_stream_ordering) AS threaded_receipt_stream_ordering
FROM receipts_linearized FROM receipts_linearized
LEFT JOIN events USING (room_id, event_id)
WHERE WHERE
user_id = ? user_id = ?
AND room_id = ? AND room_id = ?
AND stream_ordering > ? AND event_stream_ordering > ?
AND {receipt_types_clause} AND {receipt_types_clause}
GROUP BY thread_id GROUP BY thread_id
) AS receipts USING (thread_id) ) AS receipts USING (thread_id)
@ -910,9 +906,8 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
# given this function generally gets called with only one room and # given this function generally gets called with only one room and
# thread ID. # thread ID.
sql = f""" sql = f"""
SELECT room_id, thread_id, MAX(stream_ordering) SELECT room_id, thread_id, MAX(event_stream_ordering)
FROM receipts_linearized FROM receipts_linearized
INNER JOIN events USING (room_id, event_id)
WHERE {receipt_types_clause} WHERE {receipt_types_clause}
AND {thread_ids_clause} AND {thread_ids_clause}
AND {room_ids_clause} AND {room_ids_clause}
@ -1442,9 +1437,8 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
) )
sql = """ sql = """
SELECT r.stream_id, r.room_id, r.user_id, r.thread_id, e.stream_ordering SELECT r.stream_id, r.room_id, r.user_id, r.thread_id, r.event_stream_ordering
FROM receipts_linearized AS r FROM receipts_linearized AS r
INNER JOIN events AS e USING (event_id)
WHERE ? < r.stream_id AND r.stream_id <= ? AND user_id LIKE ? WHERE ? < r.stream_id AND r.stream_id <= ? AND user_id LIKE ?
ORDER BY r.stream_id ASC ORDER BY r.stream_id ASC
LIMIT ? LIMIT ?

View file

@ -178,14 +178,13 @@ class ReceiptsWorkerStore(SQLBaseStore):
) )
sql = f""" sql = f"""
SELECT event_id, stream_ordering SELECT event_id, event_stream_ordering
FROM receipts_linearized FROM receipts_linearized
INNER JOIN events USING (room_id, event_id)
WHERE {clause} WHERE {clause}
AND user_id = ? AND user_id = ?
AND room_id = ? AND room_id = ?
AND thread_id IS NULL AND thread_id IS NULL
ORDER BY stream_ordering DESC ORDER BY event_stream_ordering DESC
LIMIT 1 LIMIT 1
""" """
@ -735,10 +734,13 @@ class ReceiptsWorkerStore(SQLBaseStore):
thread_clause = "r.thread_id = ?" thread_clause = "r.thread_id = ?"
thread_args = (thread_id,) thread_args = (thread_id,)
# If the receipt doesn't have a stream ordering it is because we
# don't have the associated event, and so must be a remote receipt.
# Hence it's safe to just allow new receipts to clobber it.
sql = f""" sql = f"""
SELECT stream_ordering, event_id FROM events SELECT r.event_stream_ordering, r.event_id FROM receipts_linearized AS r
INNER JOIN receipts_linearized AS r USING (event_id, room_id) WHERE r.room_id = ? AND r.receipt_type = ? AND r.user_id = ?
WHERE r.room_id = ? AND r.receipt_type = ? AND r.user_id = ? AND {thread_clause} AND r.event_stream_ordering IS NOT NULL AND {thread_clause}
""" """
txn.execute( txn.execute(
sql, sql,

View file

@ -1101,6 +1101,56 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
}, },
) )
def test_has_different_keys(self) -> None:
"""check that has_different_keys returns True when the keys provided are different to what
is in the database."""
local_user = "@boris:" + self.hs.hostname
keys1 = {
"master_key": {
# private key: 2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0
"user_id": local_user,
"usage": ["master"],
"keys": {
"ed25519:nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk": "nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk"
},
}
}
self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys1))
is_different = self.get_success(
self.handler.has_different_keys(
local_user,
{
"master_key": keys1["master_key"],
},
)
)
self.assertEqual(is_different, False)
# change the usage => different keys
keys1["master_key"]["usage"] = ["develop"]
is_different = self.get_success(
self.handler.has_different_keys(
local_user,
{
"master_key": keys1["master_key"],
},
)
)
self.assertEqual(is_different, True)
keys1["master_key"]["usage"] = ["master"] # reset
# change the key => different keys
keys1["master_key"]["keys"] = {
"ed25519:nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unIc0rncs": "nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unIc0rncs"
}
is_different = self.get_success(
self.handler.has_different_keys(
local_user,
{
"master_key": keys1["master_key"],
},
)
)
self.assertEqual(is_different, True)
def test_query_devices_remote_sync(self) -> None: def test_query_devices_remote_sync(self) -> None:
"""Tests that querying keys for a remote user that we share a room with, """Tests that querying keys for a remote user that we share a room with,
but haven't yet fetched the keys for, returns the cross signing keys but haven't yet fetched the keys for, returns the cross signing keys

View file

@ -277,7 +277,8 @@ class DeleteMediaByDateSizeTestCase(_AdminMediaTests):
self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
self.assertEqual( self.assertEqual(
"Missing integer query parameter 'before_ts'", channel.json_body["error"] "Missing required integer query parameter before_ts",
channel.json_body["error"],
) )
def test_invalid_parameter(self) -> None: def test_invalid_parameter(self) -> None:
@ -320,7 +321,7 @@ class DeleteMediaByDateSizeTestCase(_AdminMediaTests):
self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
self.assertEqual( self.assertEqual(
"Query parameter size_gt must be a string representing a positive integer.", "Query parameter size_gt must be a positive integer.",
channel.json_body["error"], channel.json_body["error"],
) )

View file

@ -27,8 +27,10 @@ from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.unittest import override_config from tests.unittest import override_config
from tests.utils import HAS_AUTHLIB
endpoint = "/_matrix/client/unstable/org.matrix.msc3886/rendezvous" msc3886_endpoint = "/_matrix/client/unstable/org.matrix.msc3886/rendezvous"
msc4108_endpoint = "/_matrix/client/unstable/org.matrix.msc4108/rendezvous"
class RendezvousServletTestCase(unittest.HomeserverTestCase): class RendezvousServletTestCase(unittest.HomeserverTestCase):
@ -41,11 +43,35 @@ class RendezvousServletTestCase(unittest.HomeserverTestCase):
return self.hs return self.hs
def test_disabled(self) -> None: def test_disabled(self) -> None:
channel = self.make_request("POST", endpoint, {}, access_token=None) channel = self.make_request("POST", msc3886_endpoint, {}, access_token=None)
self.assertEqual(channel.code, 404)
channel = self.make_request("POST", msc4108_endpoint, {}, access_token=None)
self.assertEqual(channel.code, 404) self.assertEqual(channel.code, 404)
@override_config({"experimental_features": {"msc3886_endpoint": "/asd"}}) @override_config({"experimental_features": {"msc3886_endpoint": "/asd"}})
def test_redirect(self) -> None: def test_msc3886_redirect(self) -> None:
channel = self.make_request("POST", endpoint, {}, access_token=None) channel = self.make_request("POST", msc3886_endpoint, {}, access_token=None)
self.assertEqual(channel.code, 307) self.assertEqual(channel.code, 307)
self.assertEqual(channel.headers.getRawHeaders("Location"), ["/asd"]) self.assertEqual(channel.headers.getRawHeaders("Location"), ["/asd"])
@unittest.skip_unless(HAS_AUTHLIB, "requires authlib")
@override_config(
{
"disable_registration": True,
"experimental_features": {
"msc4108_delegation_endpoint": "https://asd",
"msc3861": {
"enabled": True,
"issuer": "https://issuer",
"client_id": "client_id",
"client_auth_method": "client_secret_post",
"client_secret": "client_secret",
"admin_token": "admin_token_value",
},
},
}
)
def test_msc4108_delegation(self) -> None:
channel = self.make_request("POST", msc4108_endpoint, {}, access_token=None)
self.assertEqual(channel.code, 307)
self.assertEqual(channel.headers.getRawHeaders("Location"), ["https://asd"])