diff --git a/synapse/media/_base.py b/synapse/media/_base.py index 19bca94170..12fa1425b2 100644 --- a/synapse/media/_base.py +++ b/synapse/media/_base.py @@ -46,10 +46,10 @@ from synapse.api.errors import Codes, cs_error from synapse.http.server import finish_request, respond_with_json from synapse.http.site import SynapseRequest from synapse.logging.context import make_deferred_yieldable +from synapse.util import Clock from synapse.util.stringutils import is_ascii if TYPE_CHECKING: - from synapse.media.media_storage import MultipartResponder from synapse.storage.databases.main.media_repository import LocalMedia @@ -275,8 +275,9 @@ def _can_encode_filename_as_token(x: str) -> bool: async def respond_with_multipart_responder( + clock: Clock, request: SynapseRequest, - responder: "Optional[MultipartResponder]", + responder: "Optional[Responder]", media_info: "LocalMedia", ) -> None: """ @@ -299,15 +300,22 @@ async def respond_with_multipart_responder( ) return + from synapse.media.media_storage import MultipartFileConsumer + + multipart_consumer = MultipartFileConsumer( + clock, request, media_info.media_type, {} + ) + logger.debug("Responding to media request with responder %s", responder) if media_info.media_length is not None: request.setHeader(b"Content-Length", b"%d" % (media_info.media_length,)) request.setHeader( - b"Content-Type", b"multipart/mixed; boundary=%s" % responder.boundary + b"Content-Type", + b"multipart/mixed; boundary=%s" % multipart_consumer.boundary, ) try: - await responder.write_to_consumer(request) + await responder.write_to_consumer(multipart_consumer) except Exception as e: # The majority of the time this will be due to the client having gone # away. 
Unfortunately, Twisted simply throws a generic exception at us diff --git a/synapse/media/media_repository.py b/synapse/media/media_repository.py index c335e518a0..e9725c6b14 100644 --- a/synapse/media/media_repository.py +++ b/synapse/media/media_repository.py @@ -58,7 +58,7 @@ from synapse.media._base import ( respond_with_responder, ) from synapse.media.filepath import MediaFilePaths -from synapse.media.media_storage import MediaStorage, MultipartResponder +from synapse.media.media_storage import MediaStorage from synapse.media.storage_provider import StorageProviderWrapper from synapse.media.thumbnailer import Thumbnailer, ThumbnailError from synapse.media.url_previewer import UrlPreviewer @@ -467,8 +467,9 @@ class MediaRepository: ) if federation: # this really should be a Multipart responder but just in case - assert isinstance(responder, MultipartResponder) - await respond_with_multipart_responder(request, responder, media_info) + await respond_with_multipart_responder( + self.clock, request, responder, media_info + ) else: await respond_with_responder( request, responder, media_type, media_length, upload_name diff --git a/synapse/media/media_storage.py b/synapse/media/media_storage.py index 2f55d12b6b..baf947b873 100644 --- a/synapse/media/media_storage.py +++ b/synapse/media/media_storage.py @@ -39,19 +39,24 @@ from typing import ( Tuple, Type, Union, + cast, ) from uuid import uuid4 import attr from zope.interface import implementer -from twisted.internet import defer, interfaces +from twisted.internet import interfaces from twisted.internet.defer import Deferred from twisted.internet.interfaces import IConsumer from twisted.protocols.basic import FileSender from synapse.api.errors import NotFoundError -from synapse.logging.context import defer_to_thread, make_deferred_yieldable +from synapse.logging.context import ( + defer_to_thread, + make_deferred_yieldable, + run_in_background, +) from synapse.logging.opentracing import start_active_span, trace, 
trace_with_opname from synapse.util import Clock from synapse.util.file_consumer import BackgroundFileConsumer @@ -217,14 +222,7 @@ class MediaStorage: local_path = os.path.join(self.local_media_directory, path) if os.path.exists(local_path): logger.debug("responding with local file %s", local_path) - if federation: - assert media_info is not None - boundary = uuid4().hex.encode("ascii") - return MultipartResponder( - open(local_path, "rb"), media_info, boundary - ) - else: - return FileResponder(open(local_path, "rb")) + return FileResponder(open(local_path, "rb")) logger.debug("local file %s did not exist", local_path) for provider in self.storage_providers: @@ -364,38 +362,6 @@ class FileResponder(Responder): self.open_file.close() -class MultipartResponder(Responder): - """Wraps an open file, formats the response according to MSC3916 and sends it to a - federation request. - - Args: - open_file: A file like object to be streamed to the client, - is closed when finished streaming. - media_info: metadata about the media item - boundary: bytes to use for the multipart response boundary - """ - - def __init__(self, open_file: IO, media_info: LocalMedia, boundary: bytes) -> None: - self.open_file = open_file - self.media_info = media_info - self.boundary = boundary - - def write_to_consumer(self, consumer: IConsumer) -> Deferred: - return make_deferred_yieldable( - MultipartFileSender().beginFileTransfer( - self.open_file, consumer, self.media_info.media_type, {}, self.boundary - ) - ) - - def __exit__( - self, - exc_type: Optional[Type[BaseException]], - exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType], - ) -> None: - self.open_file.close() - - class SpamMediaException(NotFoundError): """The media was blocked by a spam checker, so we simply 404 the request (in the same way as if it was quarantined). 
@implementer(interfaces.IConsumer)
@implementer(interfaces.IPushProducer)
class MultipartFileConsumer:
    """Wraps a consumer and frames everything written to it as a multipart body.

    This object sits between a producer (the source of the file bytes) and
    `wrapped_consumer` (the destination). To the producer it behaves as an
    `IConsumer`; to the wrapped consumer it registers itself as a push
    producer.

    The bytes emitted to the wrapped consumer are, in order:

      * a boundary line, a `Content-Type: application/json` header and the
        JSON-encoded `json_object` (written lazily on the first `write`);
      * a boundary line and a `Content-Type` header for the file part;
      * every chunk passed to `write`, unmodified;
      * the closing boundary, written by `unregisterProducer`.

    Args:
        clock: used to yield between successive pulls from a pull producer.
        wrapped_consumer: the consumer that receives the framed output.
        file_content_type: the content type advertised for the file part.
        json_object: the JSON object written as the first part.
    """

    def __init__(
        self,
        clock: Clock,
        wrapped_consumer: interfaces.IConsumer,
        file_content_type: str,
        json_object: JsonDict,
    ) -> None:
        self.clock = clock
        self.wrapped_consumer = wrapped_consumer
        self.json_field = json_object
        self.json_field_written = False
        self.content_type_written = False
        self.file_content_type = file_content_type
        # A fresh random boundary per response; callers read this attribute to
        # build the response's Content-Type header.
        self.boundary = uuid4().hex.encode("ascii")

        # The producer feeding us, and whether it is a push (streaming)
        # producer. Both are unset until `registerProducer` is called.
        self.producer: Optional["interfaces.IProducer"] = None
        self.streaming: Optional[bool] = None

        # Only meaningful for pull producers: set to stop the pull loop.
        self.paused = False

    ### IConsumer APIs ###

    def registerProducer(
        self, producer: "interfaces.IProducer", streaming: bool
    ) -> None:
        """
        Register to receive data from a producer.

        This sets self to be a consumer for a producer, and registers self as
        a push producer on the wrapped consumer.

        For L{IPullProducer} providers, a background loop is started which
        calls C{resumeProducing} repeatedly until we are paused.

        For L{IPushProducer} providers nothing further is needed: the producer
        is assumed to start in an un-paused state, and pause/resume requests
        from the wrapped consumer are forwarded to it.

        @param streaming: C{True} if C{producer} provides L{IPushProducer},
        C{False} if C{producer} provides L{IPullProducer}.
        """
        self.producer = producer
        self.streaming = streaming

        self.wrapped_consumer.registerProducer(self, True)

        # The pull loop asserts that the producer is not streaming, so it must
        # only be started for pull producers.
        if not streaming:
            run_in_background(self._resumeProducingRepeatedly)

    def unregisterProducer(self) -> None:
        """
        Stop consuming data from a producer, without disconnecting.

        Writes the closing multipart boundary to the wrapped consumer before
        unregistering from it.
        """
        # NOTE(review): if the producer never called `write` (e.g. an empty
        # file), the JSON and file part headers have not been emitted and only
        # the closing boundary is sent -- confirm whether that case can occur.
        self.wrapped_consumer.write(CRLF + b"--" + self.boundary + b"--" + CRLF)
        self.wrapped_consumer.unregisterProducer()
        self.paused = True

    def write(self, data: bytes) -> None:
        """
        The producer will write data by calling this method.

        On the first call the JSON part and the file part's content-type
        header are written to the wrapped consumer ahead of the data.
        """
        if not self.json_field_written:
            self.wrapped_consumer.write(CRLF + b"--" + self.boundary + CRLF)

            content_type = Header(b"Content-Type", b"application/json")
            self.wrapped_consumer.write(bytes(content_type) + CRLF)

            json_field = json.dumps(self.json_field)
            json_bytes = json_field.encode("utf-8")
            self.wrapped_consumer.write(json_bytes)
            self.wrapped_consumer.write(CRLF + b"--" + self.boundary + CRLF)

            self.json_field_written = True

        # if we haven't written the content type yet, do so
        if not self.content_type_written:
            file_type = self.file_content_type.encode("utf-8")
            content_type = Header(b"Content-Type", file_type)
            self.wrapped_consumer.write(bytes(content_type) + CRLF)
            self.content_type_written = True

        self.wrapped_consumer.write(data)

    ### IPushProducer APIs ###

    def stopProducing(self) -> None:
        """
        Stop producing data.

        This tells a producer that its consumer has died, so it must stop
        producing data for good.
        """
        assert self.producer is not None

        self.paused = True
        self.producer.stopProducing()

    def pauseProducing(self) -> None:
        """
        Pause producing data.

        Tells a producer that it has produced too much data to process for
        the time being, and to stop until C{resumeProducing()} is called.
        """
        assert self.producer is not None

        # For a pull producer this flag is what stops the pull loop; a push
        # producer additionally needs to be told to pause.
        self.paused = True

        if self.streaming:
            cast("interfaces.IPushProducer", self.producer).pauseProducing()

    def resumeProducing(self) -> None:
        """
        Resume producing data.

        A push producer is simply told to resume; for a pull producer the
        background pull loop is restarted.
        """
        assert self.producer is not None

        if self.streaming:
            cast("interfaces.IPushProducer", self.producer).resumeProducing()
            return

        run_in_background(self._resumeProducingRepeatedly)

    async def _resumeProducingRepeatedly(self) -> None:
        """Drive a pull producer until we are paused.

        Repeatedly calls the producer's `resumeProducing`, awaiting
        `clock.sleep(0)` between calls, and exits once `self.paused` is set
        (by `pauseProducing`, `stopProducing` or `unregisterProducer`).
        """
        assert self.producer is not None
        assert not self.streaming

        producer = cast("interfaces.IPullProducer", self.producer)

        self.paused = False
        while not self.paused:
            producer.resumeProducing()
            # presumably yields control back to the reactor between chunks --
            # TODO confirm against Clock.sleep
            await self.clock.sleep(0)